diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index ac9a2e75..ff261bad 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -3,7 +3,7 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} USER vscode -RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.35.0" RYE_INSTALL_OPTION="--yes" bash +RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.44.0" RYE_INSTALL_OPTION="--yes" bash ENV PATH=/home/vscode/.rye/shims:$PATH -RUN echo "[[ -d .venv ]] && source .venv/bin/activate" >> /home/vscode/.bashrc +RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index bbeb30b1..c17fdc16 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -24,6 +24,9 @@ } } } + }, + "features": { + "ghcr.io/devcontainers/features/node:1": {} } // Features to add to the dev container. More info: https://containers.dev/features. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c8a8a4f7..5eaaecd9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,18 +1,23 @@ name: CI on: push: - branches: - - main + branches-ignore: + - 'generated' + - 'codegen/**' + - 'integrated/**' + - 'stl-preview-head/**' + - 'stl-preview-base/**' pull_request: - branches: - - main - - next + branches-ignore: + - 'stl-preview-head/**' + - 'stl-preview-base/**' jobs: lint: + timeout-minutes: 10 name: lint - runs-on: ubuntu-latest - + runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork steps: - uses: actions/checkout@v4 @@ -21,7 +26,7 @@ jobs: curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: '0.35.0' + RYE_VERSION: '0.44.0' RYE_INSTALL_OPTION: '--yes' - name: Install dependencies @@ -30,10 +35,51 @@ jobs: - name: Run lints run: ./scripts/lint + build: + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + timeout-minutes: 10 + name: build + permissions: + contents: read + id-token: write + runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + steps: + - uses: actions/checkout@v4 + + - name: Install Rye + run: | + curl -sSf https://rye.astral.sh/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: '0.44.0' + RYE_INSTALL_OPTION: '--yes' + + - name: Install dependencies + run: rye sync --all-features + + - name: Run build + run: rye build + + - name: Get GitHub OIDC Token + if: github.repository == 'stainless-sdks/brainbase-python' + id: github-oidc + uses: actions/github-script@v6 + with: + script: core.setOutput('github_token', await core.getIDToken()); + + - name: Upload tarball + if: github.repository == 'stainless-sdks/brainbase-python' + env: + URL: https://pkg.stainless.com/s + AUTH: ${{ steps.github-oidc.outputs.github_token }} + SHA: ${{ github.sha }} + run: ./scripts/utils/upload-artifact.sh + test: + timeout-minutes: 10 name: test - runs-on: ubuntu-latest - + runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork steps: - uses: actions/checkout@v4 @@ -42,7 +88,7 @@ jobs: curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: '0.35.0' + RYE_VERSION: '0.44.0' RYE_INSTALL_OPTION: '--yes' - name: Bootstrap diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 3fbc99f8..ec603561 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -21,7 +21,7 @@ jobs: curl -sSf https://rye.astral.sh/get | bash echo "$HOME/.rye/shims" >> $GITHUB_PATH env: - RYE_VERSION: '0.35.0' + RYE_VERSION: '0.44.0' RYE_INSTALL_OPTION: '--yes' - name: Publish to PyPI diff --git a/.gitignore b/.gitignore index 87797408..95ceb189 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ .prism.log -.vscode _dev __pycache__ diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 127ac87b..3b4c2d4b 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "4.0.0" + ".": "4.1.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 774ad327..087e4350 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,2 +1,4 @@ configured_endpoints: 15 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/brainbase-egrigokhan%2Fbrainbase-ab4ce60666d2503f2b7028d55b9f75cc42a76a668cda26576e91b851ea650b0b.yml +openapi_spec_hash: ec07d4f39ed4cb03f93255b680ca2f35 +config_hash: 6f322d88e08375a924b420a5ee9f269c diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..5b010307 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.importFormat": "relative", +} diff --git a/CHANGELOG.md b/CHANGELOG.md index aab837a5..926a771f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,129 @@ # Changelog +## 4.1.0 (2026-01-06) + +Full Changelog: [v4.0.0...v4.1.0](https://github.com/BrainbaseHQ/brainbase-python-sdk/compare/v4.0.0...v4.1.0) + +### Features + +* **api:** update via SDK Studio ([#61](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/61)) ([9c0c551](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/9c0c551d114f385be676f42aab844de1309b9406)) +* clean up environment call outs ([696f18b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/696f18b140f81774258c1bf92e36e070fbf65c7e)) +* **client:** add follow_redirects request option ([6eb41c9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6eb41c9ae22aeeae7139e4cdc0bb9e0f34bb37ff)) +* **client:** add support for aiohttp ([6f6ddd9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6f6ddd94c735b6bcf82920259c45f2a112f12507)) +* **client:** allow passing `NotGiven` for body ([#67](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/67)) ([3ad7f25](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3ad7f25e58b4823792d6475c7c14a3700273dd13)) +* **client:** send `X-Stainless-Read-Timeout` header ([#63](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/63)) ([a594c75](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/a594c7501a919a038a7b4d08754ee837acd45b89)) +* **client:** support file upload requests ([fde965a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/fde965a36bf905c2e8ee5e29013141fca52d6e92)) +* improve future compat with pydantic v3 ([bccbddf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bccbddf086f9791bcd117701a886c16306c4c4f4)) +* **types:** replace List[str] with SequenceNotStr in params ([0578887](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0578887c11b95955522042771b4680d5a9f4f6b9)) + + +### Bug Fixes + +* asyncify on non-asyncio runtimes ([#66](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/66)) ([ca310cd](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ca310cd8beb201d9ad2c66ab13c1e0ed605e6a91)) +* avoid newer type syntax ([db10820](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/db108205529a83cadc51c9506fc2c7a8c372d964)) +* **ci:** correct conditional ([66eb0ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/66eb0ef07df1d265cbb31619ec71f2e097c42f4f)) +* **ci:** ensure pip is always available ([#78](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/78)) ([d3d295a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d3d295a6a4007e2504000988e6b997a5a96be28b)) +* **ci:** release-doctor — report correct token name ([65cdccf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/65cdccfa4575fb71d6c4947ba13e210537ec181e)) +* **ci:** remove publishing patch ([#79](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/79)) ([493f504](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/493f504df7150a24d8373065b217c8b69320b432)) +* **client:** close streams without requiring full consumption ([e783145](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/e783145b013ec8af7142e357f9ec4d9d7ddc99b1)) +* **client:** correctly parse binary response | stream ([3924997](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/392499744e9c7f7b6d770f243ae700d772663fad)) +* **client:** don't send Content-Type header on GET requests ([6bcbc4e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6bcbc4e13f22a42e2a69d7f62c8148430bbd88d9)) +* **client:** mark some request bodies as optional ([3ad7f25](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3ad7f25e58b4823792d6475c7c14a3700273dd13)) +* compat with Python 3.14 ([f7d6103](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f7d6103a91fad3aa7aba7eb7174afe1b2cfea3b9)) +* **compat:** update signatures of `model_dump` and `model_dump_json` for Pydantic v1 ([898ca7b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/898ca7bd4cf577eeeab764fbcdd9a77518b44587)) +* ensure streams are always closed ([f89af68](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f89af6811daed337f503b3c3ad4cb93ab8b38936)) +* **package:** support direct resource imports ([ad2d130](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ad2d130ead5da49f75340d1c642a00854f9d29b6)) +* **parsing:** correctly handle nested discriminated unions ([cd511d2](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/cd511d2d51ded4d55af012cdc475c2aa79be5785)) +* **parsing:** ignore empty metadata ([bdd8ead](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bdd8eadd20c5e4f34b85c3b8e922831f5f1074ab)) +* **parsing:** parse extra field types ([470d8a8](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/470d8a8851db80f9c3720320a5f51ef214504c52)) +* **perf:** optimize some hot paths ([7cc4937](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7cc4937c3882b956356579443e99826de72a4025)) +* **perf:** skip traversing types for NotGiven values ([38509ba](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/38509baa0c8b7fe6cd68ee34fe769ae1c0d95a5c)) +* **pydantic v1:** more robust ModelField.annotation check ([3dc3480](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3dc3480e627d3199117ef6a2b869d413f6408f7b)) +* **tests:** fix: tests which call HTTP endpoints directly with the example parameters ([539215f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/539215ff08dfbe906fad296b7a09fbd3a2bdb5bf)) +* **types:** allow pyright to infer TypedDict types within SequenceNotStr ([84b4806](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/84b4806cef26c829d8f2cea180d7db87be025788)) +* **types:** handle more discriminated union shapes ([#77](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/77)) ([8b6dcf0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8b6dcf01f11bed61a53a66b45ab3ae255deb82e8)) +* use async_to_httpx_files in patch method ([6bfd0c0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6bfd0c0e4e37cf063c32a82525914a4e6da16546)) + + +### Chores + +* add Python 3.14 classifier and testing ([c172e9c](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/c172e9c16b4ca6fd3c03631a9f58c85c6132de7d)) +* broadly detect json family of content-type headers ([febefbc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/febefbc87ca1d2bbfadc9932400b5ae42644f49a)) +* bump `httpx-aiohttp` version to 0.1.9 ([7aeb4c8](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7aeb4c87f034ee7f506a64adf52a349efb343cf5)) +* **ci:** add timeout thresholds for CI jobs ([d5cbcd0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d5cbcd06c3d27b71601e34e9ea8fb2a5461ec37d)) +* **ci:** change upload type ([d7e4405](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d7e4405ad6aa676d5f21fe603b705f2bf7d36496)) +* **ci:** enable for pull requests ([1ea6fbc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1ea6fbc396ca2c40b1233e7e07ab81b6f1b19620)) +* **ci:** fix installation instructions ([6291f4a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6291f4ad4b3f4ae3624f3ef3dafc3ef5c9ae4ced)) +* **ci:** only run for pushes and fork pull requests ([5d00f3e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5d00f3ed62f35deb17229244613a788ad45bbdc7)) +* **ci:** only use depot for staging repos ([19ee773](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/19ee77377796949bd043a65b8e6934a60a9aa455)) +* **ci:** upload sdks to package manager ([598ec7e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/598ec7e8bf03284b8db9e94049d5e42a1adef26b)) +* **client:** minor internal fixes ([1e29d3b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1e29d3b13de115b6047b3a712ceeac69f77ba51f)) +* **deps:** mypy 1.18.1 has a regression, pin to 1.17 ([fd940b6](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/fd940b62669bcb40a24c2be17e5ab75d116c4ce2)) +* do not install brew dependencies in ./scripts/bootstrap by default ([67c48f0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/67c48f09376cd85fbb5dca63dc47d83c106e7be6)) +* **docs:** grammar improvements ([540e711](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/540e711335d2757eb6c76138bb5cc8fceae04ff3)) +* **docs:** remove reference to rye shell ([ec32daa](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ec32daa0c329830982eabdffde4356b3dafdfb41)) +* **docs:** update client docstring ([#71](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/71)) ([b41543a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/b41543a89e16e504054f95ef8987e12bef2da3a2)) +* **docs:** use environment variables for authentication in code snippets ([ff7f0ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ff7f0ef0716e3bbe9973b18a4a69b5608376f701)) +* fix typos ([#80](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/80)) ([c1576cc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/c1576ccb8e5592b395e435cc949ba7c077211b67)) +* **internal/tests:** avoid race condition with implicit client cleanup ([d3a5435](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d3a5435f4e7c9609de5141ec44639b109758a49b)) +* **internal:** add `--fix` argument to lint script ([0b1e68e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0b1e68e074af9dc9a8e60a539c964433c66c3887)) +* **internal:** add missing files argument to base client ([5547d1b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5547d1ba450d6ccb13963d24699f89f487e7dc78)) +* **internal:** add Sequence related utils ([5f815ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5f815ef63af5a60d77c4f792f3aaf42b317be58f)) +* **internal:** avoid errors for isinstance checks on proxies ([5dc0949](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5dc09490c13d71ac41d430d78e9c32f8f59c330e)) +* **internal:** base client updates ([0ac179a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0ac179a512ed4773d3717a4226c7f041c67eab17)) +* **internal:** bump pinned h11 dep ([6cafe07](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6cafe072603e69522dd754cf2d3f2a59b8f49271)) +* **internal:** bump pyright version ([81c2baf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/81c2bafbaab324b6e4c85c1e193148fbe2c88525)) +* **internal:** bump rye to 0.44.0 ([#76](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/76)) ([21a20b3](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/21a20b383e1ce4b072bdbcdefc5774f9e2ba21f4)) +* **internal:** change ci workflow machines ([656643d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/656643d1bd5e3bba6a6997ac3ecbce69d0aaf4cc)) +* **internal:** codegen related update ([0b4efb7](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0b4efb79489847366f43e83b5be51c68eddac2bd)) +* **internal:** codegen related update ([2f8fbc4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/2f8fbc4b1ebdcf56188bb58f763e2fb9fd8c202b)) +* **internal:** codegen related update ([d6d5a1d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d6d5a1d5cf3fec7ee383f25b8927d773a20230ca)) +* **internal:** codegen related update ([a145cee](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/a145cee50fe2c30203c1beea85b9e87e58339103)) +* **internal:** codegen related update ([#75](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/75)) ([db19786](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/db197864c745b7235e546f349e0a24a7c8cd9801)) +* **internal:** detect missing future annotations with ruff ([23b94ed](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/23b94edc4b2f2e1ab1b8ed02881296a185b3ccac)) +* **internal:** expand CI branch coverage ([7fd1145](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7fd1145852e5f82cd271dcb4432e2474e1cbd9d4)) +* **internal:** fix devcontainers setup ([#68](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/68)) ([97b7254](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/97b725436eaca395261fa04119202ccd938e2edd)) +* **internal:** fix list file params ([1b5e333](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1b5e333c271ee804b2fe7738dfb2ee7bd0044c9a)) +* **internal:** fix ruff target version ([1437b86](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1437b86f88760d95f63a3cd9e5e4e3d1b9cc1673)) +* **internal:** fix type traversing dictionary params ([#64](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/64)) ([1322c80](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1322c808afea9e8d17ac958e289850115c8d0fe8)) +* **internal:** grammar fix (it's -> its) ([8b1cdb7](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8b1cdb737ecd09456400b79a03d1aa2d02ab8e2b)) +* **internal:** import reformatting ([8a3f6f0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8a3f6f0fc98525d48998f38b6bafc6b78bbea73d)) +* **internal:** minor type handling changes ([#65](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/65)) ([7e69125](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7e691251edd799b8ffd067792c0833f99a6906bb)) +* **internal:** move mypy configurations to `pyproject.toml` file ([f87b268](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f87b2684a44a8c6112fa30fd935e42d9d532fdbf)) +* **internal:** properly set __pydantic_private__ ([#69](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/69)) ([bc25b84](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bc25b84c059996c26d435aa61da24f684b25f2e8)) +* **internal:** reduce CI branch coverage ([2492996](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/249299619c40b6abeef1332f3c3e33e436fe5da5)) +* **internal:** refactor retries to not use recursion ([055e329](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/055e329cdf3927632ab1e4b128c2ed5ce0333e8f)) +* **internal:** remove extra empty newlines ([#74](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/74)) ([3d90dff](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3d90dff0832619c6495bbf7f82cc74ac9aac46c9)) +* **internal:** remove trailing character ([#81](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/81)) ([4cfa80b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/4cfa80bf1f6a9f24ae0a7244b3ae5a248e131edc)) +* **internal:** remove unused http client options forwarding ([#72](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/72)) ([69a44e3](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/69a44e38bcfb88f523e036d4f434942a96397061)) +* **internal:** slight transform perf improvement ([#82](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/82)) ([5498eaf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5498eaf9154b388668be1a516d3597d27b74cc7c)) +* **internal:** update comment in script ([103820e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/103820ebecd7a1dd118282d8becf5db502a1ec0d)) +* **internal:** update conftest.py ([83531b4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/83531b46d24f38f1f7c6a3ca13ddbb4e5d16590b)) +* **internal:** update models test ([421a2b5](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/421a2b57d0973344ae3d86ec0f989ca655490970)) +* **internal:** update pydantic dependency ([5dd09a4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5dd09a48ae3a7ef8efe86304ea1d84acd0bf7d39)) +* **internal:** update pyright exclude list ([f306088](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f3060880c80762252f85cb071eae6b7c1263265f)) +* **internal:** update pyright settings ([2d267e1](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/2d267e1d9a8fee34b0dde0ca9274afac975a1648)) +* **package:** drop Python 3.8 support ([95ca18d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/95ca18d4426eac5714e12b54f5f0fb181aaa85e6)) +* **package:** mark python 3.13 as supported ([06ad64f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/06ad64facff9ca27b1049dc9bc2a41aa1872b7ee)) +* **project:** add settings file for vscode ([6754b39](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6754b3937cfa144183e8d4633efdd19e2d0e1413)) +* **readme:** fix version rendering on pypi ([b1e9e51](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/b1e9e51bb3adf1c0f2387b4aff1804ff6e40e346)) +* **readme:** update badges ([f3f214f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f3f214f315d721421f383ba4b5e63ddd828c0b59)) +* speedup initial import ([3b3f4e6](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3b3f4e63e4488509e60062a71cf5f3f569b4e529)) +* **tests:** add tests for httpx client instantiation & proxies ([5e172cd](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5e172cd1be541e04ab62ef1c770b5d1e9b7fc44d)) +* **tests:** run tests in parallel ([497b381](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/497b3816f4167e344dbbad4a2910236436641777)) +* **tests:** simplify `get_platform` test ([ef07d85](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ef07d85cf3d9001104e57e9e1a266aa65f6852e5)) +* **tests:** skip some failing tests on the latest python versions ([39d037c](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/39d037c9713c26103a549a3b0c03003a93ef3fc4)) +* **types:** change optional parameter type from NotGiven to Omit ([14cb9a9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/14cb9a994e36c901e9b73ca1004a40cb30f87786)) +* update @stainless-api/prism-cli to v5.15.0 ([d766a01](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d766a015e4fd9a56336bf37be4cce0eb08361a20)) +* update github action ([d0f9d9e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d0f9d9e2bd470f7ee9f2c29cf514d9a9da093c0b)) +* update lockfile ([77726a0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/77726a0c2aeb678ef7995da7be5aa667388917ca)) + + +### Documentation + +* **client:** fix httpx.Timeout documentation reference ([3341b85](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3341b85797cfcc35e3e8c1f3a151176c85587ec8)) +* update URLs from stainlessapi.com to stainless.com ([#70](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/70)) ([08062c4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/08062c48906749ce529d5ea1fcdfed3b1b18412c)) + ## 4.0.0 (2025-02-04) Full Changelog: [v3.0.0...v4.0.0](https://github.com/BrainbaseHQ/brainbase-python-sdk/compare/v3.0.0...v4.0.0) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2cf8f44a..37d38185 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,8 +17,7 @@ $ rye sync --all-features You can then run scripts using `rye run python script.py` or by activating the virtual environment: ```sh -$ rye shell -# or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work +# Activate the virtual environment - https://docs.python.org/3/library/venv.html#how-venvs-work $ source .venv/bin/activate # now you can omit the `rye run` prefix diff --git a/LICENSE b/LICENSE index 62446133..2ab165fa 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2025 Brainbase + Copyright 2026 Brainbase Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 33ff2fa4..33243d9c 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,17 @@ # Brainbase Python API library -[![PyPI version](https://img.shields.io/pypi/v/brainbase-labs.svg)](https://pypi.org/project/brainbase-labs/) + +[![PyPI version](https://img.shields.io/pypi/v/brainbase-labs.svg?label=pypi%20(stable))](https://pypi.org/project/brainbase-labs/) -The Brainbase Python library provides convenient access to the Brainbase REST API from any Python 3.8+ +The Brainbase Python library provides convenient access to the Brainbase REST API from any Python 3.9+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). -It is generated with [Stainless](https://www.stainlessapi.com/). +It is generated with [Stainless](https://www.stainless.com/). ## Documentation -The REST API documentation can be found on [docs.usebrainbase.xyz](https://docs.usebrainbase.xyz). The full API of this library can be found in [api.md](api.md). +The REST API documentation can be found on [docs.usebrainbase.com](https://docs.usebrainbase.com). The full API of this library can be found in [api.md](api.md). ## Installation @@ -62,6 +63,37 @@ asyncio.run(main()) Functionality between the synchronous and asynchronous clients is otherwise identical. +### With aiohttp + +By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend. + +You can enable this by installing `aiohttp`: + +```sh +# install from PyPI +pip install brainbase-labs[aiohttp] +``` + +Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: + +```python +import os +import asyncio +from brainbase import DefaultAioHttpClient +from brainbase import AsyncBrainbase + + +async def main() -> None: + async with AsyncBrainbase( + api_key=os.environ.get("API_KEY"), # This is the default and can be omitted + http_client=DefaultAioHttpClient(), + ) as client: + workers = await client.workers.list() + + +asyncio.run(main()) +``` + ## Using types Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like: @@ -136,7 +168,7 @@ client.with_options(max_retries=5).workers.list() ### Timeouts By default requests time out after 1 minute. You can configure this with a `timeout` option, -which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object: +which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object: ```python from brainbase import Brainbase @@ -322,7 +354,7 @@ print(brainbase.__version__) ## Requirements -Python 3.8 or higher. +Python 3.9 or higher. ## Contributing diff --git a/SECURITY.md b/SECURITY.md index 752bd79a..f5c1b88f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,9 +2,9 @@ ## Reporting Security Issues -This SDK is generated by [Stainless Software Inc](http://stainlessapi.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. +This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. -To report a security issue, please contact the Stainless team at security@stainlessapi.com. +To report a security issue, please contact the Stainless team at security@stainless.com. ## Responsible Disclosure @@ -16,11 +16,11 @@ before making any information public. ## Reporting Non-SDK Related Security Issues If you encounter security issues that are not directly related to SDKs but pertain to the services -or products provided by Brainbase please follow the respective company's security reporting guidelines. +or products provided by Brainbase, please follow the respective company's security reporting guidelines. ### Brainbase Terms and Policies -Please contact dev@brainbase.com for any questions or concerns regarding security of our services. +Please contact dev-feedback@brainbase.com for any questions or concerns regarding the security of our services. --- diff --git a/api.md b/api.md index 2022bf88..ecd4ff83 100644 --- a/api.md +++ b/api.md @@ -26,19 +26,14 @@ Methods: Types: ```python -from brainbase.types.workers.deployments import ( - VoiceCreateResponse, - VoiceRetrieveResponse, - VoiceUpdateResponse, - VoiceListResponse, -) +from brainbase.types.workers.deployments import VoiceDeployment, VoiceListResponse ``` Methods: -- client.workers.deployments.voice.create(worker_id, \*\*params) -> VoiceCreateResponse -- client.workers.deployments.voice.retrieve(deployment_id, \*, worker_id) -> VoiceRetrieveResponse -- client.workers.deployments.voice.update(deployment_id, \*, worker_id, \*\*params) -> VoiceUpdateResponse +- client.workers.deployments.voice.create(worker_id, \*\*params) -> VoiceDeployment +- client.workers.deployments.voice.retrieve(deployment_id, \*, worker_id) -> VoiceDeployment +- client.workers.deployments.voice.update(deployment_id, \*, worker_id, \*\*params) -> VoiceDeployment - client.workers.deployments.voice.list(worker_id) -> VoiceListResponse - client.workers.deployments.voice.delete(deployment_id, \*, worker_id) -> None diff --git a/bin/check-release-environment b/bin/check-release-environment index 59422a48..b845b0f4 100644 --- a/bin/check-release-environment +++ b/bin/check-release-environment @@ -3,7 +3,7 @@ errors=() if [ -z "${PYPI_TOKEN}" ]; then - errors+=("The BRAINBASE_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") + errors+=("The PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") fi lenErrors=${#errors[@]} diff --git a/bin/publish-pypi b/bin/publish-pypi index 05bfccbb..826054e9 100644 --- a/bin/publish-pypi +++ b/bin/publish-pypi @@ -3,7 +3,4 @@ set -eux mkdir -p dist rye build --clean -# Patching importlib-metadata version until upstream library version is updated -# https://github.com/pypa/twine/issues/977#issuecomment-2189800841 -"$HOME/.rye/self/bin/python3" -m pip install 'importlib-metadata==7.2.1' rye publish --yes --token=$PYPI_TOKEN diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index e062bd87..00000000 --- a/mypy.ini +++ /dev/null @@ -1,50 +0,0 @@ -[mypy] -pretty = True -show_error_codes = True - -# Exclude _files.py because mypy isn't smart enough to apply -# the correct type narrowing and as this is an internal module -# it's fine to just use Pyright. -# -# We also exclude our `tests` as mypy doesn't always infer -# types correctly and Pyright will still catch any type errors. -exclude = ^(src/brainbase/_files\.py|_dev/.*\.py|tests/.*)$ - -strict_equality = True -implicit_reexport = True -check_untyped_defs = True -no_implicit_optional = True - -warn_return_any = True -warn_unreachable = True -warn_unused_configs = True - -# Turn these options off as it could cause conflicts -# with the Pyright options. -warn_unused_ignores = False -warn_redundant_casts = False - -disallow_any_generics = True -disallow_untyped_defs = True -disallow_untyped_calls = True -disallow_subclassing_any = True -disallow_incomplete_defs = True -disallow_untyped_decorators = True -cache_fine_grained = True - -# By default, mypy reports an error if you assign a value to the result -# of a function call that doesn't return anything. We do this in our test -# cases: -# ``` -# result = ... -# assert result is None -# ``` -# Changing this codegen to make mypy happy would increase complexity -# and would not be worth it. -disable_error_code = func-returns-value,overload-cannot-match - -# https://github.com/python/mypy/issues/12162 -[mypy.overrides] -module = "black.files.*" -ignore_errors = true -ignore_missing_imports = true diff --git a/pyproject.toml b/pyproject.toml index ac778a61..300483ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,29 +1,32 @@ [project] name = "brainbase-labs" -version = "4.0.0" +version = "4.1.0" description = "The official Python library for the brainbase API" dynamic = ["readme"] license = "Apache-2.0" authors = [ -{ name = "Brainbase", email = "dev@brainbase.com" }, +{ name = "Brainbase", email = "dev-feedback@brainbase.com" }, ] + dependencies = [ - "httpx>=0.23.0, <1", - "pydantic>=1.9.0, <3", - "typing-extensions>=4.10, <5", - "anyio>=3.5.0, <5", - "distro>=1.7.0, <2", - "sniffio", + "httpx>=0.23.0, <1", + "pydantic>=1.9.0, <3", + "typing-extensions>=4.10, <5", + "anyio>=3.5.0, <5", + "distro>=1.7.0, <2", + "sniffio", ] -requires-python = ">= 3.8" + +requires-python = ">= 3.9" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", @@ -37,14 +40,15 @@ classifiers = [ Homepage = "https://github.com/BrainbaseHQ/brainbase-python-sdk" Repository = "https://github.com/BrainbaseHQ/brainbase-python-sdk" - +[project.optional-dependencies] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "pyright>=1.1.359", - "mypy", + "pyright==1.1.399", + "mypy==1.17", "respx", "pytest", "pytest-asyncio", @@ -54,7 +58,7 @@ dev-dependencies = [ "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", "rich>=13.7.1", - "nest_asyncio==1.6.0", + "pytest-xdist>=3.6.1", ] [tool.rye.scripts] @@ -87,7 +91,7 @@ typecheck = { chain = [ "typecheck:mypy" = "mypy ." [build-system] -requires = ["hatchling", "hatch-fancy-pypi-readme"] +requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"] build-backend = "hatchling.build" [tool.hatch.build] @@ -126,7 +130,7 @@ replacement = '[\1](https://github.com/BrainbaseHQ/brainbase-python-sdk/tree/mai [tool.pytest.ini_options] testpaths = ["tests"] -addopts = "--tb=short" +addopts = "--tb=short -n auto" xfail_strict = true asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" @@ -139,24 +143,77 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.8" +pythonVersion = "3.9" exclude = [ "_dev", ".venv", ".nox", + ".git", ] reportImplicitOverride = true +reportOverlappingOverload = false reportImportCycles = false reportPrivateUsage = false +[tool.mypy] +pretty = true +show_error_codes = true + +# Exclude _files.py because mypy isn't smart enough to apply +# the correct type narrowing and as this is an internal module +# it's fine to just use Pyright. +# +# We also exclude our `tests` as mypy doesn't always infer +# types correctly and Pyright will still catch any type errors. +exclude = ['src/brainbase/_files.py', '_dev/.*.py', 'tests/.*'] + +strict_equality = true +implicit_reexport = true +check_untyped_defs = true +no_implicit_optional = true + +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true + +# Turn these options off as it could cause conflicts +# with the Pyright options. +warn_unused_ignores = false +warn_redundant_casts = false + +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_subclassing_any = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +cache_fine_grained = true + +# By default, mypy reports an error if you assign a value to the result +# of a function call that doesn't return anything. We do this in our test +# cases: +# ``` +# result = ... +# assert result is None +# ``` +# Changing this codegen to make mypy happy would increase complexity +# and would not be worth it. +disable_error_code = "func-returns-value,overload-cannot-match" + +# https://github.com/python/mypy/issues/12162 +[[tool.mypy.overrides]] +module = "black.files.*" +ignore_errors = true +ignore_missing_imports = true + [tool.ruff] line-length = 120 output-format = "grouped" -target-version = "py37" +target-version = "py38" [tool.ruff.format] docstring-code-format = true @@ -169,6 +226,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -191,6 +250,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" diff --git a/requirements-dev.lock b/requirements-dev.lock index 2ccbd89c..8a56d82f 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -7,97 +7,143 @@ # all-features: true # with-sources: false # generate-hashes: false +# universal: false -e file:. -annotated-types==0.6.0 +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.2 + # via brainbase-labs + # via httpx-aiohttp +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 # via pydantic -anyio==4.4.0 +anyio==4.12.0 # via brainbase-labs # via httpx -argcomplete==3.1.2 +argcomplete==3.6.3 + # via nox +async-timeout==5.0.1 + # via aiohttp +attrs==25.4.0 + # via aiohttp # via nox -certifi==2023.7.22 +backports-asyncio-runner==1.2.0 + # via pytest-asyncio +certifi==2025.11.12 # via httpcore # via httpx -colorlog==6.7.0 +colorlog==6.10.1 + # via nox +dependency-groups==1.3.1 # via nox -dirty-equals==0.6.0 -distlib==0.3.7 +dirty-equals==0.11 +distlib==0.4.0 # via virtualenv -distro==1.8.0 +distro==1.9.0 # via brainbase-labs -exceptiongroup==1.2.2 +exceptiongroup==1.3.1 # via anyio # via pytest -filelock==3.12.4 +execnet==2.1.2 + # via pytest-xdist +filelock==3.19.1 # via virtualenv -h11==0.14.0 +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +h11==0.16.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via brainbase-labs + # via httpx-aiohttp # via respx -idna==3.4 +httpx-aiohttp==0.1.9 + # via brainbase-labs +humanize==4.13.0 + # via nox +idna==3.11 # via anyio # via httpx -importlib-metadata==7.0.0 -iniconfig==2.0.0 + # via yarl +importlib-metadata==8.7.0 +iniconfig==2.1.0 # via pytest markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -mypy==1.14.1 -mypy-extensions==1.0.0 +multidict==6.7.0 + # via aiohttp + # via yarl +mypy==1.17.0 +mypy-extensions==1.1.0 # via mypy -nest-asyncio==1.6.0 -nodeenv==1.8.0 +nodeenv==1.9.1 # via pyright -nox==2023.4.22 -packaging==23.2 +nox==2025.11.12 +packaging==25.0 + # via dependency-groups # via nox # via pytest -platformdirs==3.11.0 +pathspec==0.12.1 + # via mypy +platformdirs==4.4.0 # via virtualenv -pluggy==1.5.0 +pluggy==1.6.0 # via pytest -pydantic==2.10.3 +propcache==0.4.1 + # via aiohttp + # via yarl +pydantic==2.12.5 # via brainbase-labs -pydantic-core==2.27.1 +pydantic-core==2.41.5 # via pydantic -pygments==2.18.0 +pygments==2.19.2 + # via pytest # via rich -pyright==1.1.392.post0 -pytest==8.3.3 +pyright==1.1.399 +pytest==8.4.2 # via pytest-asyncio -pytest-asyncio==0.24.0 -python-dateutil==2.8.2 + # via pytest-xdist +pytest-asyncio==1.2.0 +pytest-xdist==3.8.0 +python-dateutil==2.9.0.post0 # via time-machine -pytz==2023.3.post1 - # via dirty-equals respx==0.22.0 -rich==13.7.1 -ruff==0.9.4 -setuptools==68.2.2 - # via nodeenv -six==1.16.0 +rich==14.2.0 +ruff==0.14.7 +six==1.17.0 # via python-dateutil -sniffio==1.3.0 - # via anyio +sniffio==1.3.1 # via brainbase-labs -time-machine==2.9.0 -tomli==2.0.2 +time-machine==2.19.0 +tomli==2.3.0 + # via dependency-groups # via mypy + # via nox # via pytest -typing-extensions==4.12.2 +typing-extensions==4.15.0 + # via aiosignal # via anyio # via brainbase-labs + # via exceptiongroup + # via multidict # via mypy # via pydantic # via pydantic-core # via pyright -virtualenv==20.24.5 + # via pytest-asyncio + # via typing-inspection + # via virtualenv +typing-inspection==0.4.2 + # via pydantic +virtualenv==20.35.4 # via nox -zipp==3.17.0 +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index ab585024..46d844a3 100644 --- a/requirements.lock +++ b/requirements.lock @@ -7,38 +7,70 @@ # all-features: true # with-sources: false # generate-hashes: false +# universal: false -e file:. -annotated-types==0.6.0 +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.2 + # via brainbase-labs + # via httpx-aiohttp +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 # via pydantic -anyio==4.4.0 +anyio==4.12.0 # via brainbase-labs # via httpx -certifi==2023.7.22 +async-timeout==5.0.1 + # via aiohttp +attrs==25.4.0 + # via aiohttp +certifi==2025.11.12 # via httpcore # via httpx -distro==1.8.0 +distro==1.9.0 # via brainbase-labs -exceptiongroup==1.2.2 +exceptiongroup==1.3.1 # via anyio -h11==0.14.0 +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +h11==0.16.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via brainbase-labs -idna==3.4 + # via httpx-aiohttp +httpx-aiohttp==0.1.9 + # via brainbase-labs +idna==3.11 # via anyio # via httpx -pydantic==2.10.3 + # via yarl +multidict==6.7.0 + # via aiohttp + # via yarl +propcache==0.4.1 + # via aiohttp + # via yarl +pydantic==2.12.5 # via brainbase-labs -pydantic-core==2.27.1 +pydantic-core==2.41.5 # via pydantic -sniffio==1.3.0 - # via anyio +sniffio==1.3.1 # via brainbase-labs -typing-extensions==4.12.2 +typing-extensions==4.15.0 + # via aiosignal # via anyio # via brainbase-labs + # via exceptiongroup + # via multidict # via pydantic # via pydantic-core + # via typing-inspection +typing-inspection==0.4.2 + # via pydantic +yarl==1.22.0 + # via aiohttp diff --git a/scripts/bootstrap b/scripts/bootstrap index e84fe62c..b430fee3 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -4,10 +4,18 @@ set -e cd "$(dirname "$0")/.." -if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies…" - brew bundle + echo -n "==> Install Homebrew dependencies? (y/N): " + read -r response + case "$response" in + [yY][eE][sS]|[yY]) + brew bundle + ;; + *) + ;; + esac + echo } fi diff --git a/scripts/lint b/scripts/lint index fa0c0399..83623cb6 100755 --- a/scripts/lint +++ b/scripts/lint @@ -4,8 +4,13 @@ set -e cd "$(dirname "$0")/.." -echo "==> Running lints" -rye run lint +if [ "$1" = "--fix" ]; then + echo "==> Running lints with --fix" + rye run fix:ruff +else + echo "==> Running lints" + rye run lint +fi echo "==> Making sure it imports" rye run python -c 'import brainbase' diff --git a/scripts/mock b/scripts/mock index d2814ae6..0b28f6ea 100755 --- a/scripts/mock +++ b/scripts/mock @@ -21,7 +21,7 @@ echo "==> Starting mock server with URL ${URL}" # Run prism mock on the given spec if [ "$1" == "--daemon" ]; then - npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log & + npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &> .prism.log & # Wait for server to come online echo -n "Waiting for server" @@ -37,5 +37,5 @@ if [ "$1" == "--daemon" ]; then echo else - npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" + npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" fi diff --git a/scripts/test b/scripts/test index 4fa5698b..dbeda2d2 100755 --- a/scripts/test +++ b/scripts/test @@ -43,7 +43,7 @@ elif ! prism_is_running ; then echo -e "To run the server, pass in the path or url of your OpenAPI" echo -e "spec to the prism command:" echo - echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" + echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}" echo exit 1 @@ -52,6 +52,8 @@ else echo fi +export DEFER_PYDANTIC_BUILD=false + echo "==> Running tests" rye run pytest "$@" diff --git a/scripts/utils/upload-artifact.sh b/scripts/utils/upload-artifact.sh new file mode 100755 index 00000000..64d4441e --- /dev/null +++ b/scripts/utils/upload-artifact.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -exuo pipefail + +FILENAME=$(basename dist/*.whl) + +RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \ + -H "Authorization: Bearer $AUTH" \ + -H "Content-Type: application/json") + +SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url') + +if [[ "$SIGNED_URL" == "null" ]]; then + echo -e "\033[31mFailed to get signed URL.\033[0m" + exit 1 +fi + +UPLOAD_RESPONSE=$(curl -v -X PUT \ + -H "Content-Type: binary/octet-stream" \ + --data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1) + +if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then + echo -e "\033[32mUploaded build to Stainless storage.\033[0m" + echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/brainbase-python/$SHA/$FILENAME'\033[0m" +else + echo -e "\033[31mFailed to upload artifact.\033[0m" + exit 1 +fi diff --git a/src/brainbase/__init__.py b/src/brainbase/__init__.py index 9f3a57b1..0d991ee6 100644 --- a/src/brainbase/__init__.py +++ b/src/brainbase/__init__.py @@ -1,7 +1,9 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. +import typing as _t + from . import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given from ._utils import file_from_path from ._client import ( Client, @@ -34,7 +36,7 @@ UnprocessableEntityError, APIResponseValidationError, ) -from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient +from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging __all__ = [ @@ -46,7 +48,9 @@ "ProxiesTypes", "NotGiven", "NOT_GIVEN", + "not_given", "Omit", + "omit", "BrainbaseError", "APIError", "APIStatusError", @@ -76,8 +80,12 @@ "DEFAULT_CONNECTION_LIMITS", "DefaultHttpxClient", "DefaultAsyncHttpxClient", + "DefaultAioHttpClient", ] +if not _t.TYPE_CHECKING: + from ._utils._resources_proxy import resources as resources + _setup_logging() # Update the __module__ attribute for exported symbols so that diff --git a/src/brainbase/_base_client.py b/src/brainbase/_base_client.py index 9bb716ae..7849a30a 100644 --- a/src/brainbase/_base_client.py +++ b/src/brainbase/_base_client.py @@ -9,7 +9,6 @@ import inspect import logging import platform -import warnings import email.utils from types import TracebackType from random import random @@ -36,14 +35,13 @@ import httpx import distro import pydantic -from httpx import URL, Limits +from httpx import URL from pydantic import PrivateAttr from . import _exceptions from ._qs import Querystring from ._files import to_httpx_files, async_to_httpx_files from ._types import ( - NOT_GIVEN, Body, Omit, Query, @@ -51,19 +49,17 @@ Timeout, NotGiven, ResponseT, - Transport, AnyMapping, PostParser, - ProxiesTypes, RequestFiles, HttpxSendArgs, - AsyncTransport, RequestOptions, HttpxRequestFiles, ModelBuilderProtocol, + not_given, ) from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping -from ._compat import model_copy, model_dump +from ._compat import PYDANTIC_V1, model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( APIResponse, @@ -102,7 +98,11 @@ _AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any]) if TYPE_CHECKING: - from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + from httpx._config import ( + DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage] + ) + + HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG else: try: from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT @@ -119,6 +119,7 @@ class PageInfo: url: URL | NotGiven params: Query | NotGiven + json: Body | NotGiven @overload def __init__( @@ -134,19 +135,30 @@ def __init__( params: Query, ) -> None: ... + @overload + def __init__( + self, + *, + json: Body, + ) -> None: ... + def __init__( self, *, - url: URL | NotGiven = NOT_GIVEN, - params: Query | NotGiven = NOT_GIVEN, + url: URL | NotGiven = not_given, + json: Body | NotGiven = not_given, + params: Query | NotGiven = not_given, ) -> None: self.url = url + self.json = json self.params = params @override def __repr__(self) -> str: if self.url: return f"{self.__class__.__name__}(url={self.url})" + if self.json: + return f"{self.__class__.__name__}(json={self.json})" return f"{self.__class__.__name__}(params={self.params})" @@ -195,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: options.url = str(url) return options + if not isinstance(info.json, NotGiven): + if not is_mapping(info.json): + raise TypeError("Pagination is only supported with mappings") + + if not options.json_data: + options.json_data = {**info.json} + else: + if not is_mapping(options.json_data): + raise TypeError("Pagination is only supported with mappings") + + options.json_data = {**options.json_data, **info.json} + return options + raise ValueError("Unexpected PageInfo state") @@ -207,6 +232,9 @@ def _set_private_attributes( model: Type[_T], options: FinalRequestOptions, ) -> None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: + self.__pydantic_private__ = {} + self._model = model self._client = client self._options = options @@ -292,6 +320,9 @@ def _set_private_attributes( client: AsyncAPIClient, options: FinalRequestOptions, ) -> None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: + self.__pydantic_private__ = {} + self._model = model self._client = client self._options = options @@ -331,9 +362,6 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]): _base_url: URL max_retries: int timeout: Union[float, Timeout, None] - _limits: httpx.Limits - _proxies: ProxiesTypes | None - _transport: Transport | AsyncTransport | None _strict_response_validation: bool _idempotency_header: str | None _default_stream_cls: type[_DefaultStreamT] | None = None @@ -346,9 +374,6 @@ def __init__( _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout | None = DEFAULT_TIMEOUT, - limits: httpx.Limits, - transport: Transport | AsyncTransport | None, - proxies: ProxiesTypes | None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: @@ -356,9 +381,6 @@ def __init__( self._base_url = self._enforce_trailing_slash(URL(base_url)) self.max_retries = max_retries self.timeout = timeout - self._limits = limits - self._proxies = proxies - self._transport = transport self._custom_headers = custom_headers or {} self._custom_query = custom_query or {} self._strict_response_validation = _strict_response_validation @@ -415,13 +437,20 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0 headers = httpx.Headers(headers_dict) idempotency_header = self._idempotency_header - if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: - headers[idempotency_header] = options.idempotency_key or self._idempotency_key() + if idempotency_header and options.idempotency_key and idempotency_header not in headers: + headers[idempotency_header] = options.idempotency_key - # Don't set the retry count header if it was already set or removed by the caller. We check + # Don't set these headers if they were already set or removed by the caller. We check # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. - if "x-stainless-retry-count" not in (header.lower() for header in custom_headers): + lower_custom_headers = [header.lower() for header in custom_headers] + if "x-stainless-retry-count" not in lower_custom_headers: headers["x-stainless-retry-count"] = str(retries_taken) + if "x-stainless-read-timeout" not in lower_custom_headers: + timeout = self.timeout if isinstance(options.timeout, NotGiven) else options.timeout + if isinstance(timeout, Timeout): + timeout = timeout.read + if timeout is not None: + headers["x-stainless-read-timeout"] = str(timeout) return headers @@ -500,6 +529,18 @@ def _build_request( # work around https://github.com/encode/httpx/discussions/2880 kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + is_body_allowed = options.method.lower() != "get" + + if is_body_allowed: + if isinstance(json_data, bytes): + kwargs["content"] = json_data + else: + kwargs["json"] = json_data if is_given(json_data) else None + kwargs["files"] = files + else: + headers.pop("Content-Type", None) + kwargs.pop("data", None) + # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, @@ -511,8 +552,6 @@ def _build_request( # so that passing a `TypedDict` doesn't cause an error. # https://github.com/microsoft/pyright/issues/3526#event-6715453066 params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None, - json=json_data, - files=files, **kwargs, ) @@ -556,7 +595,7 @@ def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalReques # we internally support defining a temporary header to override the # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` # see _response.py for implementation details - override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given) if is_given(override_cast_to): options.headers = headers return cast(Type[ResponseT], override_cast_to) @@ -786,47 +825,12 @@ def __init__( version: str, base_url: str | URL, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - transport: Transport | None = None, - proxies: ProxiesTypes | None = None, - limits: Limits | None = None, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, ) -> None: - kwargs: dict[str, Any] = {} - if limits is not None: - warnings.warn( - "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") - else: - limits = DEFAULT_CONNECTION_LIMITS - - if transport is not None: - kwargs["transport"] = transport - warnings.warn( - "The `transport` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `transport`") - - if proxies is not None: - kwargs["proxies"] = proxies - warnings.warn( - "The `proxies` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `proxies`") - if not is_given(timeout): # if the user passed in a custom http client with a non-default # timeout set then we use that timeout. @@ -847,12 +851,9 @@ def __init__( super().__init__( version=version, - limits=limits, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - proxies=proxies, base_url=base_url, - transport=transport, max_retries=max_retries, custom_query=custom_query, custom_headers=custom_headers, @@ -862,9 +863,6 @@ def __init__( base_url=base_url, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - limits=limits, - follow_redirects=True, - **kwargs, # type: ignore ) def is_closed(self) -> bool: @@ -914,7 +912,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[True], stream_cls: Type[_StreamT], @@ -925,7 +922,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[False] = False, ) -> ResponseT: ... @@ -935,7 +931,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: Type[_StreamT] | None = None, @@ -945,121 +940,112 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) + cast_to = self._maybe_override_cast_to(cast_to, options) - def _request( - self, - *, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - retries_taken: int, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) + if input_options.idempotency_key is None and input_options.method.lower() != "get": + # ensure the idempotency key is reused between requests + input_options.idempotency_key = self._idempotency_key() - cast_to = self._maybe_override_cast_to(cast_to, options) - options = self._prepare_options(options) + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - self._prepare_request(request) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = self._prepare_options(options) - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + self._prepare_request(request) - log.debug("Sending HTTP Request: %s %s", request.method, request.url) - - try: - response = self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + if options.follow_redirects is not None: + kwargs["follow_redirects"] = options.follow_redirects - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, + response = None + try: + response = self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err - - log.debug( - 'HTTP Response: %s %s "%i %s" %s', - request.method, - request.url, - response.status_code, - response.reason_phrase, - response.headers, - ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + err.response.close() + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - err.response.close() - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + err.response.read() - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - err.response.read() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return self._process_response( cast_to=cast_to, options=options, @@ -1069,37 +1055,20 @@ def _request( retries_taken=retries_taken, ) - def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) - # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a - # different thread if necessary. time.sleep(timeout) - return self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - def _process_response( self, *, @@ -1112,7 +1081,14 @@ def _process_response( ) -> ResponseT: origin = get_origin(cast_to) or cast_to - if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if ( + inspect.isclass(origin) + and issubclass(origin, BaseAPIResponse) + # we only want to actually return the custom BaseAPIResponse class if we're + # returning the raw response, or if we're not streaming SSE, as if we're streaming + # SSE then `cast_to` doesn't actively reflect the type we need to parse into + and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER))) + ): if not issubclass(origin, APIResponse): raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") @@ -1271,9 +1247,12 @@ def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=to_httpx_files(files), **options + ) return self.request(cast_to, opts) def put( @@ -1323,6 +1302,24 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) +try: + import httpx_aiohttp +except ImportError: + + class _DefaultAioHttpClient(httpx.AsyncClient): + def __init__(self, **_kwargs: Any) -> None: + raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra") +else: + + class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + + super().__init__(**kwargs) + + if TYPE_CHECKING: DefaultAsyncHttpxClient = httpx.AsyncClient """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK @@ -1331,8 +1328,12 @@ def __init__(self, **kwargs: Any) -> None: This is useful because overriding the `http_client` with your own instance of `httpx.AsyncClient` will result in httpx's defaults being used, not ours. """ + + DefaultAioHttpClient = httpx.AsyncClient + """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`.""" else: DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient + DefaultAioHttpClient = _DefaultAioHttpClient class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient): @@ -1358,46 +1359,11 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - transport: AsyncTransport | None = None, - proxies: ProxiesTypes | None = None, - limits: Limits | None = None, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: - kwargs: dict[str, Any] = {} - if limits is not None: - warnings.warn( - "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") - else: - limits = DEFAULT_CONNECTION_LIMITS - - if transport is not None: - kwargs["transport"] = transport - warnings.warn( - "The `transport` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `transport`") - - if proxies is not None: - kwargs["proxies"] = proxies - warnings.warn( - "The `proxies` argument is deprecated. The `http_client` argument should be passed instead", - category=DeprecationWarning, - stacklevel=3, - ) - if http_client is not None: - raise ValueError("The `http_client` argument is mutually exclusive with `proxies`") - if not is_given(timeout): # if the user passed in a custom http client with a non-default # timeout set then we use that timeout. @@ -1419,11 +1385,8 @@ def __init__( super().__init__( version=version, base_url=base_url, - limits=limits, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - proxies=proxies, - transport=transport, max_retries=max_retries, custom_query=custom_query, custom_headers=custom_headers, @@ -1433,9 +1396,6 @@ def __init__( base_url=base_url, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - limits=limits, - follow_redirects=True, - **kwargs, # type: ignore ) def is_closed(self) -> bool: @@ -1484,7 +1444,6 @@ async def request( options: FinalRequestOptions, *, stream: Literal[False] = False, - remaining_retries: Optional[int] = None, ) -> ResponseT: ... @overload @@ -1495,7 +1454,6 @@ async def request( *, stream: Literal[True], stream_cls: type[_AsyncStreamT], - remaining_retries: Optional[int] = None, ) -> _AsyncStreamT: ... @overload @@ -1506,7 +1464,6 @@ async def request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: ... async def request( @@ -1516,116 +1473,114 @@ async def request( *, stream: bool = False, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, - ) -> ResponseT | _AsyncStreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return await self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) - - async def _request( - self, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - *, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - retries_taken: int, ) -> ResponseT | _AsyncStreamT: if self._platform is None: # `get_platform` can make blocking IO calls so we # execute it earlier while we are in an async context self._platform = await asyncify(get_platform)() + cast_to = self._maybe_override_cast_to(cast_to, options) + # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) + if input_options.idempotency_key is None and input_options.method.lower() != "get": + # ensure the idempotency key is reused between requests + input_options.idempotency_key = self._idempotency_key() - cast_to = self._maybe_override_cast_to(cast_to, options) - options = await self._prepare_options(options) + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - await self._prepare_request(request) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = await self._prepare_options(options) - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + await self._prepare_request(request) - try: - response = await self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + if options.follow_redirects is not None: + kwargs["follow_redirects"] = options.follow_redirects - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, + response = None + try: + response = await self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + await err.response.aclose() + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - log.debug( - 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + await err.response.aread() - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - await err.response.aclose() - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) - - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - await err.response.aread() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return await self._process_response( cast_to=cast_to, options=options, @@ -1635,35 +1590,20 @@ async def _request( retries_taken=retries_taken, ) - async def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - ) -> ResponseT | _AsyncStreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + async def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) - return await self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - async def _process_response( self, *, @@ -1676,7 +1616,14 @@ async def _process_response( ) -> ResponseT: origin = get_origin(cast_to) or cast_to - if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if ( + inspect.isclass(origin) + and issubclass(origin, BaseAPIResponse) + # we only want to actually return the custom BaseAPIResponse class if we're + # returning the raw response, or if we're not streaming SSE, as if we're streaming + # SSE then `cast_to` doesn't actively reflect the type we need to parse into + and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER))) + ): if not issubclass(origin, AsyncAPIResponse): raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}") @@ -1823,9 +1770,12 @@ async def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options + ) return await self.request(cast_to, opts) async def put( @@ -1874,8 +1824,8 @@ def make_request_options( extra_query: Query | None = None, extra_body: Body | None = None, idempotency_key: str | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - post_parser: PostParser | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + post_parser: PostParser | NotGiven = not_given, ) -> RequestOptions: """Create a dict of type RequestOptions without keys of NotGiven values.""" options: RequestOptions = {} diff --git a/src/brainbase/_client.py b/src/brainbase/_client.py index c42e10d6..541ea815 100644 --- a/src/brainbase/_client.py +++ b/src/brainbase/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Mapping from typing_extensions import Self, override import httpx @@ -11,18 +11,16 @@ from . import _exceptions from ._qs import Querystring from ._types import ( - NOT_GIVEN, Omit, Timeout, NotGiven, Transport, ProxiesTypes, RequestOptions, + not_given, ) -from ._utils import ( - is_given, - get_async_library, -) +from ._utils import is_given, get_async_library +from ._compat import cached_property from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import APIStatusError, BrainbaseError @@ -31,7 +29,10 @@ SyncAPIClient, AsyncAPIClient, ) -from .resources.workers import workers + +if TYPE_CHECKING: + from .resources import workers + from .resources.workers.workers import WorkersResource, AsyncWorkersResource __all__ = [ "Timeout", @@ -46,10 +47,6 @@ class Brainbase(SyncAPIClient): - workers: workers.WorkersResource - with_raw_response: BrainbaseWithRawResponse - with_streaming_response: BrainbaseWithStreamedResponse - # client options api_key: str @@ -58,7 +55,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -76,7 +73,7 @@ def __init__( # part of our public interface in the future. _strict_response_validation: bool = False, ) -> None: - """Construct a new synchronous brainbase client instance. + """Construct a new synchronous Brainbase client instance. This automatically infers the `api_key` argument from the `API_KEY` environment variable if it is not provided. """ @@ -104,9 +101,19 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.workers = workers.WorkersResource(self) - self.with_raw_response = BrainbaseWithRawResponse(self) - self.with_streaming_response = BrainbaseWithStreamedResponse(self) + @cached_property + def workers(self) -> WorkersResource: + from .resources.workers import WorkersResource + + return WorkersResource(self) + + @cached_property + def with_raw_response(self) -> BrainbaseWithRawResponse: + return BrainbaseWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> BrainbaseWithStreamedResponse: + return BrainbaseWithStreamedResponse(self) @property @override @@ -133,9 +140,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -214,10 +221,6 @@ def _make_status_error( class AsyncBrainbase(AsyncAPIClient): - workers: workers.AsyncWorkersResource - with_raw_response: AsyncBrainbaseWithRawResponse - with_streaming_response: AsyncBrainbaseWithStreamedResponse - # client options api_key: str @@ -226,7 +229,7 @@ def __init__( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -244,7 +247,7 @@ def __init__( # part of our public interface in the future. _strict_response_validation: bool = False, ) -> None: - """Construct a new async brainbase client instance. + """Construct a new async AsyncBrainbase client instance. This automatically infers the `api_key` argument from the `API_KEY` environment variable if it is not provided. """ @@ -272,9 +275,19 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.workers = workers.AsyncWorkersResource(self) - self.with_raw_response = AsyncBrainbaseWithRawResponse(self) - self.with_streaming_response = AsyncBrainbaseWithStreamedResponse(self) + @cached_property + def workers(self) -> AsyncWorkersResource: + from .resources.workers import AsyncWorkersResource + + return AsyncWorkersResource(self) + + @cached_property + def with_raw_response(self) -> AsyncBrainbaseWithRawResponse: + return AsyncBrainbaseWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncBrainbaseWithStreamedResponse: + return AsyncBrainbaseWithStreamedResponse(self) @property @override @@ -301,9 +314,9 @@ def copy( *, api_key: str | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, - max_retries: int | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -382,23 +395,55 @@ def _make_status_error( class BrainbaseWithRawResponse: + _client: Brainbase + def __init__(self, client: Brainbase) -> None: - self.workers = workers.WorkersResourceWithRawResponse(client.workers) + self._client = client + + @cached_property + def workers(self) -> workers.WorkersResourceWithRawResponse: + from .resources.workers import WorkersResourceWithRawResponse + + return WorkersResourceWithRawResponse(self._client.workers) class AsyncBrainbaseWithRawResponse: + _client: AsyncBrainbase + def __init__(self, client: AsyncBrainbase) -> None: - self.workers = workers.AsyncWorkersResourceWithRawResponse(client.workers) + self._client = client + + @cached_property + def workers(self) -> workers.AsyncWorkersResourceWithRawResponse: + from .resources.workers import AsyncWorkersResourceWithRawResponse + + return AsyncWorkersResourceWithRawResponse(self._client.workers) class BrainbaseWithStreamedResponse: + _client: Brainbase + def __init__(self, client: Brainbase) -> None: - self.workers = workers.WorkersResourceWithStreamingResponse(client.workers) + self._client = client + + @cached_property + def workers(self) -> workers.WorkersResourceWithStreamingResponse: + from .resources.workers import WorkersResourceWithStreamingResponse + + return WorkersResourceWithStreamingResponse(self._client.workers) class AsyncBrainbaseWithStreamedResponse: + _client: AsyncBrainbase + def __init__(self, client: AsyncBrainbase) -> None: - self.workers = workers.AsyncWorkersResourceWithStreamingResponse(client.workers) + self._client = client + + @cached_property + def workers(self) -> workers.AsyncWorkersResourceWithStreamingResponse: + from .resources.workers import AsyncWorkersResourceWithStreamingResponse + + return AsyncWorkersResourceWithStreamingResponse(self._client.workers) Client = Brainbase diff --git a/src/brainbase/_compat.py b/src/brainbase/_compat.py index 92d9ee61..bdef67f0 100644 --- a/src/brainbase/_compat.py +++ b/src/brainbase/_compat.py @@ -12,14 +12,13 @@ _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) -# --------------- Pydantic v2 compatibility --------------- +# --------------- Pydantic v2, v3 compatibility --------------- # Pyright incorrectly reports some of our functions as overriding a method when they don't # pyright: reportIncompatibleMethodOverride=false -PYDANTIC_V2 = pydantic.VERSION.startswith("2.") +PYDANTIC_V1 = pydantic.VERSION.startswith("1.") -# v1 re-exports if TYPE_CHECKING: def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 @@ -44,90 +43,92 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 ... else: - if PYDANTIC_V2: - from pydantic.v1.typing import ( + # v1 re-exports + if PYDANTIC_V1: + from pydantic.typing import ( get_args as get_args, is_union as is_union, get_origin as get_origin, is_typeddict as is_typeddict, is_literal_type as is_literal_type, ) - from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime else: - from pydantic.typing import ( + from ._utils import ( get_args as get_args, is_union as is_union, get_origin as get_origin, + parse_date as parse_date, is_typeddict as is_typeddict, + parse_datetime as parse_datetime, is_literal_type as is_literal_type, ) - from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # refactored config if TYPE_CHECKING: from pydantic import ConfigDict as ConfigDict else: - if PYDANTIC_V2: - from pydantic import ConfigDict - else: + if PYDANTIC_V1: # TODO: provide an error message here? ConfigDict = None + else: + from pydantic import ConfigDict as ConfigDict # renamed methods / properties def parse_obj(model: type[_ModelT], value: object) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(value) - else: + if PYDANTIC_V1: return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + else: + return model.model_validate(value) def field_is_required(field: FieldInfo) -> bool: - if PYDANTIC_V2: - return field.is_required() - return field.required # type: ignore + if PYDANTIC_V1: + return field.required # type: ignore + return field.is_required() def field_get_default(field: FieldInfo) -> Any: value = field.get_default() - if PYDANTIC_V2: - from pydantic_core import PydanticUndefined - - if value == PydanticUndefined: - return None + if PYDANTIC_V1: return value + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None return value def field_outer_type(field: FieldInfo) -> Any: - if PYDANTIC_V2: - return field.annotation - return field.outer_type_ # type: ignore + if PYDANTIC_V1: + return field.outer_type_ # type: ignore + return field.annotation def get_model_config(model: type[pydantic.BaseModel]) -> Any: - if PYDANTIC_V2: - return model.model_config - return model.__config__ # type: ignore + if PYDANTIC_V1: + return model.__config__ # type: ignore + return model.model_config def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: - if PYDANTIC_V2: - return model.model_fields - return model.__fields__ # type: ignore + if PYDANTIC_V1: + return model.__fields__ # type: ignore + return model.model_fields def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: - if PYDANTIC_V2: - return model.model_copy(deep=deep) - return model.copy(deep=deep) # type: ignore + if PYDANTIC_V1: + return model.copy(deep=deep) # type: ignore + return model.model_copy(deep=deep) def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: - if PYDANTIC_V2: - return model.model_dump_json(indent=indent) - return model.json(indent=indent) # type: ignore + if PYDANTIC_V1: + return model.json(indent=indent) # type: ignore + return model.model_dump_json(indent=indent) def model_dump( @@ -139,14 +140,14 @@ def model_dump( warnings: bool = True, mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2 or hasattr(model, "model_dump"): + if (not PYDANTIC_V1) or hasattr(model, "model_dump"): return model.model_dump( mode=mode, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, # warnings are not supported in Pydantic v1 - warnings=warnings if PYDANTIC_V2 else True, + warnings=True if PYDANTIC_V1 else warnings, ) return cast( "dict[str, Any]", @@ -159,9 +160,9 @@ def model_dump( def model_parse(model: type[_ModelT], data: Any) -> _ModelT: - if PYDANTIC_V2: - return model.model_validate(data) - return model.parse_obj(data) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + return model.model_validate(data) # generic models @@ -170,17 +171,16 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT: class GenericModel(pydantic.BaseModel): ... else: - if PYDANTIC_V2: + if PYDANTIC_V1: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... + else: # there no longer needs to be a distinction in v2 but # we still have to create our own subclass to avoid # inconsistent MRO ordering errors class GenericModel(pydantic.BaseModel): ... - else: - import pydantic.generics - - class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... - # cached properties if TYPE_CHECKING: diff --git a/src/brainbase/_files.py b/src/brainbase/_files.py index 715cc207..cc14c14f 100644 --- a/src/brainbase/_files.py +++ b/src/brainbase/_files.py @@ -69,12 +69,12 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes: return file if is_tuple_t(file): - return (file[0], _read_file_content(file[1]), *file[2:]) + return (file[0], read_file_content(file[1]), *file[2:]) raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") -def _read_file_content(file: FileContent) -> HttpxFileContent: +def read_file_content(file: FileContent) -> HttpxFileContent: if isinstance(file, os.PathLike): return pathlib.Path(file).read_bytes() return file @@ -111,12 +111,12 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: return file if is_tuple_t(file): - return (file[0], await _async_read_file_content(file[1]), *file[2:]) + return (file[0], await async_read_file_content(file[1]), *file[2:]) raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") -async def _async_read_file_content(file: FileContent) -> HttpxFileContent: +async def async_read_file_content(file: FileContent) -> HttpxFileContent: if isinstance(file, os.PathLike): return await anyio.Path(file).read_bytes() diff --git a/src/brainbase/_models.py b/src/brainbase/_models.py index 12c34b7d..ca9500b2 100644 --- a/src/brainbase/_models.py +++ b/src/brainbase/_models.py @@ -2,9 +2,11 @@ import os import inspect -from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast +import weakref +from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( + List, Unpack, Literal, ClassVar, @@ -19,7 +21,6 @@ ) import pydantic -import pydantic.generics from pydantic.fields import FieldInfo from ._types import ( @@ -50,7 +51,7 @@ strip_annotated_type, ) from ._compat import ( - PYDANTIC_V2, + PYDANTIC_V1, ConfigDict, GenericModel as BaseGenericModel, get_args, @@ -65,7 +66,7 @@ from ._constants import RAW_RESPONSE_HEADER if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema + from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema __all__ = ["BaseModel", "GenericModel"] @@ -81,11 +82,7 @@ class _ConfigProtocol(Protocol): class BaseModel(pydantic.BaseModel): - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) - ) - else: + if PYDANTIC_V1: @property @override @@ -95,6 +92,10 @@ def model_fields_set(self) -> set[str]: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore + else: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) def to_dict( self, @@ -208,28 +209,32 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] else: fields_values[name] = field_get_default(field) + extra_field_type = _get_extra_fields_type(__cls) + _extra = {} for key, value in values.items(): if key not in model_fields: - if PYDANTIC_V2: - _extra[key] = value - else: + parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value + + if PYDANTIC_V1: _fields_set.add(key) - fields_values[key] = value + fields_values[key] = parsed + else: + _extra[key] = parsed object.__setattr__(m, "__dict__", fields_values) - if PYDANTIC_V2: - # these properties are copied from Pydantic's `model_construct()` method - object.__setattr__(m, "__pydantic_private__", None) - object.__setattr__(m, "__pydantic_extra__", _extra) - object.__setattr__(m, "__pydantic_fields_set__", _fields_set) - else: + if PYDANTIC_V1: # init_private_attributes() does not exist in v2 m._init_private_attributes() # type: ignore # copied from Pydantic v1's `construct()` method object.__setattr__(m, "__fields_set__", _fields_set) + else: + # these properties are copied from Pydantic's `model_construct()` method + object.__setattr__(m, "__pydantic_private__", None) + object.__setattr__(m, "__pydantic_extra__", _extra) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) return m @@ -239,7 +244,7 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride] # although not in practice model_construct = construct - if not PYDANTIC_V2: + if PYDANTIC_V1: # we define aliases for some of the new pydantic v2 methods so # that we can just document these methods without having to specify # a specific pydantic version as some users may not know which @@ -252,13 +257,15 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + context: Any | None = None, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -267,16 +274,24 @@ def model_dump( Args: mode: The mode in which `to_python` should run. - If mode is 'json', the dictionary will only contain JSON serializable types. - If mode is 'python', the dictionary may contain any Python objects. - include: A list of fields to include in the output. - exclude: A list of fields to exclude from the output. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A set of fields to include in the output. + exclude: A set of fields to exclude from the output. + context: Additional context to pass to the serializer. by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that are unset or None from the output. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - round_trip: Whether to enable serialization and deserialization round-trip support. - warnings: Whether to log warnings when invalid fields are encountered. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + exclude_computed_fields: Whether to exclude computed fields. + While this can be useful for round-tripping, it is usually recommended to use the dedicated + `round_trip` parameter instead. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, + "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + fallback: A function to call when an unknown value is encountered. If not provided, + a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. Returns: A dictionary representation of the model. @@ -291,31 +306,38 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped @override def model_dump_json( self, *, indent: int | None = None, + ensure_ascii: bool = False, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + context: Any | None = None, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -344,11 +366,17 @@ def model_dump_json( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if ensure_ascii != False: + raise ValueError("ensure_ascii is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -359,15 +387,32 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: if value is None: return field_get_default(field) - if PYDANTIC_V2: - type_ = field.annotation - else: + if PYDANTIC_V1: type_ = cast(type, field.outer_type_) # type: ignore + else: + type_ = field.annotation # type: ignore if type_ is None: raise RuntimeError(f"Unexpected field type is None for {key}") - return construct_type(value=value, type_=type_) + return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None)) + + +def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: + if PYDANTIC_V1: + # TODO + return None + + schema = cls.__pydantic_core_schema__ + if schema["type"] == "model": + fields = schema["schema"] + if fields["type"] == "model-fields": + extras = fields.get("extras_schema") + if extras and "cls" in extras: + # mypy can't narrow the type + return extras["cls"] # type: ignore[no-any-return] + + return None def is_basemodel(type_: type) -> bool: @@ -421,20 +466,28 @@ def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: return cast(_T, construct_type(value=value, type_=type_)) -def construct_type(*, value: object, type_: object) -> object: +def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object: """Loose coercion to the expected type with construction of nested values. If the given value does not match the expected type then it is returned as-is. """ + + # store a reference to the original type we were given before we extract any inner + # types so that we can properly resolve forward references in `TypeAliasType` annotations + original_type = None + # we allow `object` as the input type because otherwise, passing things like # `Literal['value']` will be reported as a type error by type checkers type_ = cast("type[object]", type_) if is_type_alias_type(type_): + original_type = type_ # type: ignore[unreachable] type_ = type_.__value__ # type: ignore[unreachable] # unwrap `Annotated[T, ...]` -> `T` - if is_annotated_type(type_): - meta: tuple[Any, ...] = get_args(type_)[1:] + if metadata is not None and len(metadata) > 0: + meta: tuple[Any, ...] = tuple(metadata) + elif is_annotated_type(type_): + meta = get_args(type_)[1:] type_ = extract_type_arg(type_, 0) else: meta = tuple() @@ -446,7 +499,7 @@ def construct_type(*, value: object, type_: object) -> object: if is_union(origin): try: - return validate_type(type_=cast("type[object]", type_), value=value) + return validate_type(type_=cast("type[object]", original_type or type_), value=value) except Exception: pass @@ -538,6 +591,9 @@ class CachedDiscriminatorType(Protocol): __discriminator__: DiscriminatorDetails +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -580,8 +636,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached discriminator_field_name: str | None = None @@ -599,30 +656,30 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, for variant in get_args(union): variant = strip_annotated_type(variant) if is_basemodel_type(variant): - if PYDANTIC_V2: - field = _extract_field_schema_pv2(variant, discriminator_field_name) - if not field: + if PYDANTIC_V1: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field.get("serialization_alias") - - field_schema = field["schema"] + discriminator_alias = field_info.alias - if field_schema["type"] == "literal": - for entry in cast("LiteralSchema", field_schema)["expected"]: + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): if isinstance(entry, str): mapping[entry] = variant else: - field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] - if not field_info: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: continue # Note: if one variant defines an alias then they all should - discriminator_alias = field_info.alias + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] - if field_info.annotation and is_literal_type(field_info.annotation): - for entry in get_args(field_info.annotation): + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: if isinstance(entry, str): mapping[entry] = variant @@ -634,21 +691,24 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + DISCRIMINATOR_CACHE.setdefault(union, details) return details def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: schema = model.__pydantic_core_schema__ + if schema["type"] == "definitions": + schema = schema["schema"] + if schema["type"] != "model": return None + schema = cast("ModelSchema", schema) fields_schema = schema["schema"] if fields_schema["type"] != "model-fields": return None fields_schema = cast("ModelFieldsSchema", fields_schema) - field = fields_schema["fields"].get(field_name) if not field: return None @@ -672,7 +732,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None: setattr(typ, "__pydantic_config__", config) # noqa: B010 -# our use of subclasssing here causes weirdness for type checkers, +# our use of subclassing here causes weirdness for type checkers, # so we just pretend that we don't subclass if TYPE_CHECKING: GenericModel = BaseModel @@ -682,7 +742,7 @@ class GenericModel(BaseGenericModel, BaseModel): pass -if PYDANTIC_V2: +if not PYDANTIC_V1: from pydantic import TypeAdapter as _TypeAdapter _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) @@ -729,6 +789,7 @@ class FinalRequestOptionsInput(TypedDict, total=False): idempotency_key: str json_data: Body extra_json: AnyMapping + follow_redirects: bool @final @@ -742,18 +803,19 @@ class FinalRequestOptions(pydantic.BaseModel): files: Union[HttpxRequestFiles, None] = None idempotency_key: Union[str, None] = None post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() + follow_redirects: Union[bool, None] = None # It should be noted that we cannot use `json` here as that would override # a BaseModel method in an incompatible fashion. json_data: Union[Body, None] = None extra_json: Union[AnyMapping, None] = None - if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - else: + if PYDANTIC_V1: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] arbitrary_types_allowed: bool = True + else: + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) def get_max_retries(self, max_retries: int) -> int: if isinstance(self.max_retries, NotGiven): @@ -786,9 +848,9 @@ def construct( # type: ignore key: strip_not_given(value) for key, value in values.items() } - if PYDANTIC_V2: - return super().model_construct(_fields_set, **kwargs) - return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + if PYDANTIC_V1: + return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + return super().model_construct(_fields_set, **kwargs) if not TYPE_CHECKING: # type checkers incorrectly complain about this assignment diff --git a/src/brainbase/_qs.py b/src/brainbase/_qs.py index 274320ca..ada6fd3f 100644 --- a/src/brainbase/_qs.py +++ b/src/brainbase/_qs.py @@ -4,7 +4,7 @@ from urllib.parse import parse_qs, urlencode from typing_extensions import Literal, get_args -from ._types import NOT_GIVEN, NotGiven, NotGivenOr +from ._types import NotGiven, not_given from ._utils import flatten _T = TypeVar("_T") @@ -41,8 +41,8 @@ def stringify( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> str: return urlencode( self.stringify_items( @@ -56,8 +56,8 @@ def stringify_items( self, params: Params, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> list[tuple[str, str]]: opts = Options( qs=self, @@ -143,8 +143,8 @@ def __init__( self, qs: Querystring = _qs, *, - array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN, - nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, ) -> None: self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format diff --git a/src/brainbase/_response.py b/src/brainbase/_response.py index 5c0e8f01..2b33470a 100644 --- a/src/brainbase/_response.py +++ b/src/brainbase/_response.py @@ -233,7 +233,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: # split is required to handle cases where additional information is included # in the response, e.g. application/json; charset=utf-8 content_type, *_ = response.headers.get("content-type", "*").split(";") - if content_type != "application/json": + if not content_type.endswith("json"): if is_basemodel(cast_to): try: data = response.json() diff --git a/src/brainbase/_streaming.py b/src/brainbase/_streaming.py index ec0b6267..d3bc9002 100644 --- a/src/brainbase/_streaming.py +++ b/src/brainbase/_streaming.py @@ -54,12 +54,12 @@ def __stream__(self) -> Iterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - for sse in iterator: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - # Ensure the entire stream is consumed - for _sse in iterator: - ... + try: + for sse in iterator: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + response.close() def __enter__(self) -> Self: return self @@ -118,12 +118,12 @@ async def __stream__(self) -> AsyncIterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - async for sse in iterator: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + try: + async for sse in iterator: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + await response.aclose() async def __aenter__(self) -> Self: return self diff --git a/src/brainbase/_types.py b/src/brainbase/_types.py index 1a11c5e5..03b1fc5a 100644 --- a/src/brainbase/_types.py +++ b/src/brainbase/_types.py @@ -13,10 +13,21 @@ Mapping, TypeVar, Callable, + Iterator, Optional, Sequence, ) -from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import ( + Set, + Literal, + Protocol, + TypeAlias, + TypedDict, + SupportsIndex, + overload, + override, + runtime_checkable, +) import httpx import pydantic @@ -100,23 +111,27 @@ class RequestOptions(TypedDict, total=False): params: Query extra_json: AnyMapping idempotency_key: str + follow_redirects: bool # Sentinel class used until PEP 0661 is accepted class NotGiven: """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). + For parameters with a meaningful None value, we need to distinguish between + the user explicitly passing None, and the user not passing the parameter at + all. + + User code shouldn't need to use not_given directly. For example: ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... + def create(timeout: Timeout | None | NotGiven = not_given): ... - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. + create(timeout=1) # 1s timeout + create(timeout=None) # No timeout + create() # Default timeout behavior ``` """ @@ -128,13 +143,14 @@ def __repr__(self) -> str: return "NOT_GIVEN" -NotGivenOr = Union[_T, NotGiven] +not_given = NotGiven() +# for backwards compatibility: NOT_GIVEN = NotGiven() class Omit: - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: + """ + To explicitly omit something from being sent in a request, use `omit`. ```py # as the default `Content-Type` header is `application/json` that will be sent @@ -144,8 +160,8 @@ class Omit: # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' client.post(..., headers={"Content-Type": "multipart/form-data"}) - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={"Content-Type": Omit()}) + # instead you can remove the default `application/json` header by passing omit + client.post(..., headers={"Content-Type": omit}) ``` """ @@ -153,6 +169,9 @@ def __bool__(self) -> Literal[False]: return False +omit = Omit() + + @runtime_checkable class ModelBuilderProtocol(Protocol): @classmethod @@ -215,3 +234,28 @@ class _GenericAlias(Protocol): class HttpxSendArgs(TypedDict, total=False): auth: httpx.Auth + follow_redirects: bool + + +_T_co = TypeVar("_T_co", covariant=True) + + +if TYPE_CHECKING: + # This works because str.__contains__ does not accept object (either in typeshed or at runtime) + # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + # + # Note: index() and count() methods are intentionally omitted to allow pyright to properly + # infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr. + class SequenceNotStr(Protocol[_T_co]): + @overload + def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... + def __contains__(self, value: object, /) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __reversed__(self) -> Iterator[_T_co]: ... +else: + # just point this to a normal `Sequence` at runtime to avoid having to special case + # deserializing our custom sequence type + SequenceNotStr = Sequence diff --git a/src/brainbase/_utils/__init__.py b/src/brainbase/_utils/__init__.py index d4fda26f..dc64e29a 100644 --- a/src/brainbase/_utils/__init__.py +++ b/src/brainbase/_utils/__init__.py @@ -10,7 +10,6 @@ lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, - parse_date as parse_date, is_iterable as is_iterable, is_sequence as is_sequence, coerce_float as coerce_float, @@ -23,7 +22,6 @@ coerce_boolean as coerce_boolean, coerce_integer as coerce_integer, file_from_path as file_from_path, - parse_datetime as parse_datetime, strip_not_given as strip_not_given, deepcopy_minimal as deepcopy_minimal, get_async_library as get_async_library, @@ -32,12 +30,20 @@ maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, ) +from ._compat import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, +) from ._typing import ( is_list_type as is_list_type, is_union_type as is_union_type, extract_type_arg as extract_type_arg, is_iterable_type as is_iterable_type, is_required_type as is_required_type, + is_sequence_type as is_sequence_type, is_annotated_type as is_annotated_type, is_type_alias_type as is_type_alias_type, strip_annotated_type as strip_annotated_type, @@ -55,3 +61,4 @@ function_has_argument as function_has_argument, assert_signatures_in_sync as assert_signatures_in_sync, ) +from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime diff --git a/src/brainbase/_utils/_compat.py b/src/brainbase/_utils/_compat.py new file mode 100644 index 00000000..dd703233 --- /dev/null +++ b/src/brainbase/_utils/_compat.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +import typing_extensions +from typing import Any, Type, Union, Literal, Optional +from datetime import date, datetime +from typing_extensions import get_args as _get_args, get_origin as _get_origin + +from .._types import StrBytesIntFloat +from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime + +_LITERAL_TYPES = {Literal, typing_extensions.Literal} + + +def get_args(tp: type[Any]) -> tuple[Any, ...]: + return _get_args(tp) + + +def get_origin(tp: type[Any]) -> type[Any] | None: + return _get_origin(tp) + + +def is_union(tp: Optional[Type[Any]]) -> bool: + if sys.version_info < (3, 10): + return tp is Union # type: ignore[comparison-overlap] + else: + import types + + return tp is Union or tp is types.UnionType + + +def is_typeddict(tp: Type[Any]) -> bool: + return typing_extensions.is_typeddict(tp) + + +def is_literal_type(tp: Type[Any]) -> bool: + return get_origin(tp) in _LITERAL_TYPES + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + return _parse_date(value) + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + return _parse_datetime(value) diff --git a/src/brainbase/_utils/_datetime_parse.py b/src/brainbase/_utils/_datetime_parse.py new file mode 100644 index 00000000..7cb9d9e6 --- /dev/null +++ b/src/brainbase/_utils/_datetime_parse.py @@ -0,0 +1,136 @@ +""" +This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py +without the Pydantic v1 specific errors. +""" + +from __future__ import annotations + +import re +from typing import Dict, Union, Optional +from datetime import date, datetime, timezone, timedelta + +from .._types import StrBytesIntFloat + +date_expr = r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})" +time_expr = ( + r"(?P\d{1,2}):(?P\d{1,2})" + r"(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?" + r"(?PZ|[+-]\d{2}(?::?\d{2})?)?$" +) + +date_re = re.compile(f"{date_expr}$") +datetime_re = re.compile(f"{date_expr}[T ]{time_expr}") + + +EPOCH = datetime(1970, 1, 1) +# if greater than this, the number is in ms, if less than or equal it's in seconds +# (in seconds this is 11th October 2603, in ms it's 20th August 1970) +MS_WATERSHED = int(2e10) +# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 +MAX_NUMBER = int(3e20) + + +def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: + if isinstance(value, (int, float)): + return value + try: + return float(value) + except ValueError: + return None + except TypeError: + raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None + + +def _from_unix_seconds(seconds: Union[int, float]) -> datetime: + if seconds > MAX_NUMBER: + return datetime.max + elif seconds < -MAX_NUMBER: + return datetime.min + + while abs(seconds) > MS_WATERSHED: + seconds /= 1000 + dt = EPOCH + timedelta(seconds=seconds) + return dt.replace(tzinfo=timezone.utc) + + +def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]: + if value == "Z": + return timezone.utc + elif value is not None: + offset_mins = int(value[-2:]) if len(value) > 3 else 0 + offset = 60 * int(value[1:3]) + offset_mins + if value[0] == "-": + offset = -offset + return timezone(timedelta(minutes=offset)) + else: + return None + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. + + This function supports time zone offsets. When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = _get_numeric(value, "datetime") + if number is not None: + return _from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + + match = datetime_re.match(value) + if match is None: + raise ValueError("invalid datetime format") + + kw = match.groupdict() + if kw["microsecond"]: + kw["microsecond"] = kw["microsecond"].ljust(6, "0") + + tzinfo = _parse_timezone(kw.pop("tzinfo")) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_["tzinfo"] = tzinfo + + return datetime(**kw_) # type: ignore + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + """ + Parse a date/int/float/string and return a datetime.date. + + Raise ValueError if the input is well formatted but not a valid date. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = _get_numeric(value, "date") + if number is not None: + return _from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + match = date_re.match(value) + if match is None: + raise ValueError("invalid date format") + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise ValueError("invalid date format") from None diff --git a/src/brainbase/_utils/_proxy.py b/src/brainbase/_utils/_proxy.py index ffd883e9..0f239a33 100644 --- a/src/brainbase/_utils/_proxy.py +++ b/src/brainbase/_utils/_proxy.py @@ -46,7 +46,10 @@ def __dir__(self) -> Iterable[str]: @property # type: ignore @override def __class__(self) -> type: # pyright: ignore - proxied = self.__get_proxied__() + try: + proxied = self.__get_proxied__() + except Exception: + return type(self) if issubclass(type(proxied), LazyProxy): return type(proxied) return proxied.__class__ diff --git a/src/brainbase/_utils/_resources_proxy.py b/src/brainbase/_utils/_resources_proxy.py new file mode 100644 index 00000000..1ba107b7 --- /dev/null +++ b/src/brainbase/_utils/_resources_proxy.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any +from typing_extensions import override + +from ._proxy import LazyProxy + + +class ResourcesProxy(LazyProxy[Any]): + """A proxy for the `brainbase.resources` module. + + This is used so that we can lazily import `brainbase.resources` only when + needed *and* so that users can just import `brainbase` and reference `brainbase.resources` + """ + + @override + def __load__(self) -> Any: + import importlib + + mod = importlib.import_module("brainbase.resources") + return mod + + +resources = ResourcesProxy().__as_proxied__() diff --git a/src/brainbase/_utils/_sync.py b/src/brainbase/_utils/_sync.py index 8b3aaf2b..f6027c18 100644 --- a/src/brainbase/_utils/_sync.py +++ b/src/brainbase/_utils/_sync.py @@ -1,47 +1,34 @@ from __future__ import annotations -import sys import asyncio import functools -import contextvars -from typing import Any, TypeVar, Callable, Awaitable +from typing import TypeVar, Callable, Awaitable from typing_extensions import ParamSpec +import anyio +import sniffio +import anyio.to_thread + T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") -if sys.version_info >= (3, 9): - to_thread = asyncio.to_thread -else: - # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - # for Python 3.8 support - async def to_thread( - func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs - ) -> Any: - """Asynchronously run function *func* in a separate thread. - - Any *args and **kwargs supplied for this function are directly passed - to *func*. Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. +async def to_thread( + func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs +) -> T_Retval: + if sniffio.current_async_library() == "asyncio": + return await asyncio.to_thread(func, *args, **kwargs) - Returns a coroutine that can be awaited to get the eventual result of *func*. - """ - loop = asyncio.events.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) + return await anyio.to_thread.run_sync( + functools.partial(func, *args, **kwargs), + ) # inspired by `asyncer`, https://github.com/tiangolo/asyncer def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments. For python version 3.9 and above, it uses - asyncio.to_thread to run the function in a separate thread. For python version - 3.8, it uses locally defined copy of the asyncio.to_thread function which was - introduced in python 3.9. + positional and keyword arguments. Usage: diff --git a/src/brainbase/_utils/_transform.py b/src/brainbase/_utils/_transform.py index a6b62cad..52075492 100644 --- a/src/brainbase/_utils/_transform.py +++ b/src/brainbase/_utils/_transform.py @@ -5,27 +5,31 @@ import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime -from typing_extensions import Literal, get_args, override, get_type_hints +from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints import anyio import pydantic from ._utils import ( is_list, + is_given, + lru_cache, is_mapping, is_iterable, + is_sequence, ) from .._files import is_base64_file_input +from ._compat import get_origin, is_typeddict from ._typing import ( is_list_type, is_union_type, extract_type_arg, is_iterable_type, is_required_type, + is_sequence_type, is_annotated_type, strip_annotated_type, ) -from .._compat import model_dump, is_typeddict _T = TypeVar("_T") @@ -108,6 +112,7 @@ class Params(TypedDict, total=False): return cast(_T, transformed) +@lru_cache(maxsize=8096) def _get_annotated_type(type_: type) -> type | None: """If the given type is an `Annotated` type then it is returned, if not `None` is returned. @@ -126,7 +131,7 @@ def _get_annotated_type(type_: type) -> type | None: def _maybe_transform_key(key: str, type_: type) -> str: """Transform the given `data` based on the annotations provided in `type_`. - Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. + Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata. """ annotated_type = _get_annotated_type(type_) if annotated_type is None: @@ -142,6 +147,10 @@ def _maybe_transform_key(key: str, type_: type) -> str: return key +def _no_transform_needed(annotation: type) -> bool: + return annotation == float or annotation == int + + def _transform_recursive( data: object, *, @@ -160,18 +169,27 @@ def _transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type if is_typeddict(stripped_type) and is_mapping(data): return _transform_typeddict(data, stripped_type) + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + if ( # List[T] (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. @@ -179,6 +197,15 @@ def _transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -240,6 +267,11 @@ def _transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include omitted values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -303,18 +335,27 @@ async def _async_transform_recursive( Defaults to the same value as the `annotation` argument. """ + from .._compat import model_dump + if inner_type is None: inner_type = annotation stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type if is_typeddict(stripped_type) and is_mapping(data): return await _async_transform_typeddict(data, stripped_type) + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + if ( # List[T] (is_list_type(stripped_type) and is_list(data)) # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) ): # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually # intended as an iterable, so we don't transform it. @@ -322,6 +363,15 @@ async def _async_transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -383,6 +433,11 @@ async def _async_transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include omitted values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -390,3 +445,13 @@ async def _async_transform_typeddict( else: result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) return result + + +@lru_cache(maxsize=8096) +def get_type_hints( + obj: Any, + globalns: dict[str, Any] | None = None, + localns: Mapping[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) diff --git a/src/brainbase/_utils/_typing.py b/src/brainbase/_utils/_typing.py index 278749b1..193109f3 100644 --- a/src/brainbase/_utils/_typing.py +++ b/src/brainbase/_utils/_typing.py @@ -13,8 +13,9 @@ get_origin, ) +from ._utils import lru_cache from .._types import InheritsGeneric -from .._compat import is_union as _is_union +from ._compat import is_union as _is_union def is_annotated_type(typ: type) -> bool: @@ -25,6 +26,11 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_sequence_type(typ: type) -> bool: + origin = get_origin(typ) or typ + return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence + + def is_iterable_type(typ: type) -> bool: """If the given type is `typing.Iterable[T]`""" origin = get_origin(typ) or typ @@ -66,6 +72,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]: # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +@lru_cache(maxsize=8096) def strip_annotated_type(typ: type) -> type: if is_required_type(typ) or is_annotated_type(typ): return strip_annotated_type(cast(type, get_args(typ)[0])) @@ -108,7 +115,7 @@ class MyResponse(Foo[_T]): ``` """ cls = cast(object, get_origin(typ) or typ) - if cls in generic_bases: + if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains] # we're given the class directly return extract_type_arg(typ, index) diff --git a/src/brainbase/_utils/_utils.py b/src/brainbase/_utils/_utils.py index e5811bba..eec7f4a1 100644 --- a/src/brainbase/_utils/_utils.py +++ b/src/brainbase/_utils/_utils.py @@ -21,8 +21,7 @@ import sniffio -from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike -from .._compat import parse_date as parse_date, parse_datetime as parse_datetime +from .._types import Omit, NotGiven, FileTypes, HeadersLike _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -64,7 +63,7 @@ def _extract_items( try: key = path[index] except IndexError: - if isinstance(obj, NotGiven): + if not is_given(obj): # no value was provided - we can safely ignore return [] @@ -72,8 +71,16 @@ def _extract_items( from .._files import assert_is_file_content # We have exhausted the path, return the entry we found. - assert_is_file_content(obj, key=flattened_key) assert flattened_key is not None + + if is_list(obj): + files: list[tuple[str, FileTypes]] = [] + for entry in obj: + assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "") + files.append((flattened_key + "[]", cast(FileTypes, entry))) + return files + + assert_is_file_content(obj, key=flattened_key) return [(flattened_key, cast(FileTypes, obj))] index += 1 @@ -119,14 +126,14 @@ def _extract_items( return [] -def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: - return not isinstance(obj, NotGiven) +def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) and not isinstance(obj, Omit) # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input diff --git a/src/brainbase/_version.py b/src/brainbase/_version.py index 00b5ae6a..791d7b56 100644 --- a/src/brainbase/_version.py +++ b/src/brainbase/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "brainbase" -__version__ = "4.0.0" # x-release-please-version +__version__ = "4.1.0" # x-release-please-version diff --git a/src/brainbase/resources/workers/deployments/voice.py b/src/brainbase/resources/workers/deployments/voice.py index 241c6123..f0e101c3 100644 --- a/src/brainbase/resources/workers/deployments/voice.py +++ b/src/brainbase/resources/workers/deployments/voice.py @@ -4,11 +4,8 @@ import httpx -from ...._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ...._utils import ( - maybe_transform, - async_maybe_transform, -) +from ...._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given +from ...._utils import maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -19,10 +16,8 @@ ) from ...._base_client import make_request_options from ....types.workers.deployments import voice_create_params, voice_update_params +from ....types.workers.deployments.voice_deployment import VoiceDeployment from ....types.workers.deployments.voice_list_response import VoiceListResponse -from ....types.workers.deployments.voice_create_response import VoiceCreateResponse -from ....types.workers.deployments.voice_update_response import VoiceUpdateResponse -from ....types.workers.deployments.voice_retrieve_response import VoiceRetrieveResponse __all__ = ["VoiceResource", "AsyncVoiceResource"] @@ -52,16 +47,16 @@ def create( worker_id: str, *, name: str, - phone_number: str | NotGiven = NOT_GIVEN, - voice_id: str | NotGiven = NOT_GIVEN, - voice_provider: str | NotGiven = NOT_GIVEN, + phone_number: str | Omit = omit, + voice_id: str | Omit = omit, + voice_provider: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceCreateResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Create a new voice deployment @@ -98,7 +93,7 @@ def create( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceCreateResponse, + cast_to=VoiceDeployment, ) def retrieve( @@ -111,8 +106,8 @@ def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceRetrieveResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Get a single voice deployment @@ -134,7 +129,7 @@ def retrieve( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceRetrieveResponse, + cast_to=VoiceDeployment, ) def update( @@ -143,16 +138,16 @@ def update( *, worker_id: str, name: str, - phone_number: str | NotGiven = NOT_GIVEN, - voice_id: str | NotGiven = NOT_GIVEN, - voice_provider: str | NotGiven = NOT_GIVEN, + phone_number: str | Omit = omit, + voice_id: str | Omit = omit, + voice_provider: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceUpdateResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Update a voice deployment @@ -191,7 +186,7 @@ def update( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceUpdateResponse, + cast_to=VoiceDeployment, ) def list( @@ -203,7 +198,7 @@ def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> VoiceListResponse: """ Get all voice deployments for a worker @@ -237,7 +232,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a voice deployment @@ -290,16 +285,16 @@ async def create( worker_id: str, *, name: str, - phone_number: str | NotGiven = NOT_GIVEN, - voice_id: str | NotGiven = NOT_GIVEN, - voice_provider: str | NotGiven = NOT_GIVEN, + phone_number: str | Omit = omit, + voice_id: str | Omit = omit, + voice_provider: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceCreateResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Create a new voice deployment @@ -336,7 +331,7 @@ async def create( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceCreateResponse, + cast_to=VoiceDeployment, ) async def retrieve( @@ -349,8 +344,8 @@ async def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceRetrieveResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Get a single voice deployment @@ -372,7 +367,7 @@ async def retrieve( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceRetrieveResponse, + cast_to=VoiceDeployment, ) async def update( @@ -381,16 +376,16 @@ async def update( *, worker_id: str, name: str, - phone_number: str | NotGiven = NOT_GIVEN, - voice_id: str | NotGiven = NOT_GIVEN, - voice_provider: str | NotGiven = NOT_GIVEN, + phone_number: str | Omit = omit, + voice_id: str | Omit = omit, + voice_provider: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> VoiceUpdateResponse: + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VoiceDeployment: """ Update a voice deployment @@ -429,7 +424,7 @@ async def update( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=VoiceUpdateResponse, + cast_to=VoiceDeployment, ) async def list( @@ -441,7 +436,7 @@ async def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> VoiceListResponse: """ Get all voice deployments for a worker @@ -475,7 +470,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a voice deployment diff --git a/src/brainbase/resources/workers/flows.py b/src/brainbase/resources/workers/flows.py index 0d2ef89f..7bb20bc7 100644 --- a/src/brainbase/resources/workers/flows.py +++ b/src/brainbase/resources/workers/flows.py @@ -4,11 +4,8 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -53,13 +50,13 @@ def create( *, code: str, name: str, - label: str | NotGiven = NOT_GIVEN, + label: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowCreateResponse: """ Create a new flow @@ -107,7 +104,7 @@ def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowRetrieveResponse: """ Get a single flow @@ -138,15 +135,15 @@ def update( flow_id: str, *, worker_id: str, - code: str | NotGiven = NOT_GIVEN, - label: str | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + code: str | Omit = omit, + label: str | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowUpdateResponse: """ Update a flow @@ -195,7 +192,7 @@ def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowListResponse: """ Get all flows for a worker @@ -229,7 +226,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a flow @@ -283,13 +280,13 @@ async def create( *, code: str, name: str, - label: str | NotGiven = NOT_GIVEN, + label: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowCreateResponse: """ Create a new flow @@ -337,7 +334,7 @@ async def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowRetrieveResponse: """ Get a single flow @@ -368,15 +365,15 @@ async def update( flow_id: str, *, worker_id: str, - code: str | NotGiven = NOT_GIVEN, - label: str | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + code: str | Omit = omit, + label: str | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowUpdateResponse: """ Update a flow @@ -425,7 +422,7 @@ async def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> FlowListResponse: """ Get all flows for a worker @@ -459,7 +456,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a flow diff --git a/src/brainbase/resources/workers/workers.py b/src/brainbase/resources/workers/workers.py index 9c160fbb..30d1f790 100644 --- a/src/brainbase/resources/workers/workers.py +++ b/src/brainbase/resources/workers/workers.py @@ -13,11 +13,8 @@ AsyncFlowsResourceWithStreamingResponse, ) from ...types import worker_create_params, worker_update_params -from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -75,13 +72,13 @@ def create( self, *, name: str, - description: str | NotGiven = NOT_GIVEN, + description: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerCreateResponse: """ Create a new worker @@ -123,7 +120,7 @@ def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerRetrieveResponse: """ Get a single worker @@ -151,14 +148,14 @@ def update( self, id: str, *, - description: str | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + description: str | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerUpdateResponse: """ Update a worker @@ -201,7 +198,7 @@ def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerListResponse: """Get all workers for the team""" return self._get( @@ -221,7 +218,7 @@ def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a worker @@ -279,13 +276,13 @@ async def create( self, *, name: str, - description: str | NotGiven = NOT_GIVEN, + description: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerCreateResponse: """ Create a new worker @@ -327,7 +324,7 @@ async def retrieve( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerRetrieveResponse: """ Get a single worker @@ -355,14 +352,14 @@ async def update( self, id: str, *, - description: str | NotGiven = NOT_GIVEN, - name: str | NotGiven = NOT_GIVEN, + description: str | Omit = omit, + name: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerUpdateResponse: """ Update a worker @@ -405,7 +402,7 @@ async def list( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> WorkerListResponse: """Get all workers for the team""" return await self._get( @@ -425,7 +422,7 @@ async def delete( extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> None: """ Delete a worker diff --git a/src/brainbase/types/workers/deployments/__init__.py b/src/brainbase/types/workers/deployments/__init__.py index dbdb876a..e5947326 100644 --- a/src/brainbase/types/workers/deployments/__init__.py +++ b/src/brainbase/types/workers/deployments/__init__.py @@ -2,9 +2,7 @@ from __future__ import annotations +from .voice_deployment import VoiceDeployment as VoiceDeployment from .voice_create_params import VoiceCreateParams as VoiceCreateParams from .voice_list_response import VoiceListResponse as VoiceListResponse from .voice_update_params import VoiceUpdateParams as VoiceUpdateParams -from .voice_create_response import VoiceCreateResponse as VoiceCreateResponse -from .voice_update_response import VoiceUpdateResponse as VoiceUpdateResponse -from .voice_retrieve_response import VoiceRetrieveResponse as VoiceRetrieveResponse diff --git a/src/brainbase/types/workers/deployments/voice_create_response.py b/src/brainbase/types/workers/deployments/voice_deployment.py similarity index 88% rename from src/brainbase/types/workers/deployments/voice_create_response.py rename to src/brainbase/types/workers/deployments/voice_deployment.py index 2c9a7960..61ea6936 100644 --- a/src/brainbase/types/workers/deployments/voice_create_response.py +++ b/src/brainbase/types/workers/deployments/voice_deployment.py @@ -6,10 +6,10 @@ from ...._models import BaseModel -__all__ = ["VoiceCreateResponse"] +__all__ = ["VoiceDeployment"] -class VoiceCreateResponse(BaseModel): +class VoiceDeployment(BaseModel): id: str delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None) diff --git a/src/brainbase/types/workers/deployments/voice_list_response.py b/src/brainbase/types/workers/deployments/voice_list_response.py index d54f167d..d714d5dc 100644 --- a/src/brainbase/types/workers/deployments/voice_list_response.py +++ b/src/brainbase/types/workers/deployments/voice_list_response.py @@ -1,25 +1,10 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import List, Optional +from typing import List from typing_extensions import TypeAlias -from pydantic import Field as FieldInfo +from .voice_deployment import VoiceDeployment -from ...._models import BaseModel +__all__ = ["VoiceListResponse"] -__all__ = ["VoiceListResponse", "VoiceListResponseItem"] - - -class VoiceListResponseItem(BaseModel): - id: str - - delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None) - - phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None) - - voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None) - - voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None) - - -VoiceListResponse: TypeAlias = List[VoiceListResponseItem] +VoiceListResponse: TypeAlias = List[VoiceDeployment] diff --git a/src/brainbase/types/workers/deployments/voice_retrieve_response.py b/src/brainbase/types/workers/deployments/voice_retrieve_response.py deleted file mode 100644 index 15fc4abe..00000000 --- a/src/brainbase/types/workers/deployments/voice_retrieve_response.py +++ /dev/null @@ -1,21 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Optional - -from pydantic import Field as FieldInfo - -from ...._models import BaseModel - -__all__ = ["VoiceRetrieveResponse"] - - -class VoiceRetrieveResponse(BaseModel): - id: str - - delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None) - - phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None) - - voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None) - - voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None) diff --git a/src/brainbase/types/workers/deployments/voice_update_response.py b/src/brainbase/types/workers/deployments/voice_update_response.py deleted file mode 100644 index 67d8460f..00000000 --- a/src/brainbase/types/workers/deployments/voice_update_response.py +++ /dev/null @@ -1,21 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Optional - -from pydantic import Field as FieldInfo - -from ...._models import BaseModel - -__all__ = ["VoiceUpdateResponse"] - - -class VoiceUpdateResponse(BaseModel): - id: str - - delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None) - - phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None) - - voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None) - - voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None) diff --git a/tests/api_resources/test_workers.py b/tests/api_resources/test_workers.py index 28d15c32..2ae82aa6 100644 --- a/tests/api_resources/test_workers.py +++ b/tests/api_resources/test_workers.py @@ -22,7 +22,7 @@ class TestWorkers: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create(self, client: Brainbase) -> None: worker = client.workers.create( @@ -30,7 +30,7 @@ def test_method_create(self, client: Brainbase) -> None: ) assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create_with_all_params(self, client: Brainbase) -> None: worker = client.workers.create( @@ -39,7 +39,7 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None: ) assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_create(self, client: Brainbase) -> None: response = client.workers.with_raw_response.create( @@ -51,7 +51,7 @@ def test_raw_response_create(self, client: Brainbase) -> None: worker = response.parse() assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_create(self, client: Brainbase) -> None: with client.workers.with_streaming_response.create( @@ -65,7 +65,7 @@ def test_streaming_response_create(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_retrieve(self, client: Brainbase) -> None: worker = client.workers.retrieve( @@ -73,7 +73,7 @@ def test_method_retrieve(self, client: Brainbase) -> None: ) assert_matches_type(WorkerRetrieveResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_retrieve(self, client: Brainbase) -> None: response = client.workers.with_raw_response.retrieve( @@ -85,7 +85,7 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None: worker = response.parse() assert_matches_type(WorkerRetrieveResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_retrieve(self, client: Brainbase) -> None: with client.workers.with_streaming_response.retrieve( @@ -99,7 +99,7 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_retrieve(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): @@ -107,7 +107,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update(self, client: Brainbase) -> None: worker = client.workers.update( @@ -115,7 +115,7 @@ def test_method_update(self, client: Brainbase) -> None: ) assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update_with_all_params(self, client: Brainbase) -> None: worker = client.workers.update( @@ -125,7 +125,7 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None: ) assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_update(self, client: Brainbase) -> None: response = client.workers.with_raw_response.update( @@ -137,7 +137,7 @@ def test_raw_response_update(self, client: Brainbase) -> None: worker = response.parse() assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_update(self, client: Brainbase) -> None: with client.workers.with_streaming_response.update( @@ -151,7 +151,7 @@ def test_streaming_response_update(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_update(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): @@ -159,13 +159,13 @@ def test_path_params_update(self, client: Brainbase) -> None: id="", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_list(self, client: Brainbase) -> None: worker = client.workers.list() assert_matches_type(WorkerListResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_list(self, client: Brainbase) -> None: response = client.workers.with_raw_response.list() @@ -175,7 +175,7 @@ def test_raw_response_list(self, client: Brainbase) -> None: worker = response.parse() assert_matches_type(WorkerListResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_list(self, client: Brainbase) -> None: with client.workers.with_streaming_response.list() as response: @@ -187,7 +187,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_delete(self, client: Brainbase) -> None: worker = client.workers.delete( @@ -195,7 +195,7 @@ def test_method_delete(self, client: Brainbase) -> None: ) assert worker is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_delete(self, client: Brainbase) -> None: response = client.workers.with_raw_response.delete( @@ -207,7 +207,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None: worker = response.parse() assert worker is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_delete(self, client: Brainbase) -> None: with client.workers.with_streaming_response.delete( @@ -221,7 +221,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_delete(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): @@ -231,9 +231,11 @@ def test_path_params_delete(self, client: Brainbase) -> None: class TestAsyncWorkers: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.create( @@ -241,7 +243,7 @@ async def test_method_create(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.create( @@ -250,7 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) ) assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.with_raw_response.create( @@ -262,7 +264,7 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: worker = await response.parse() assert_matches_type(WorkerCreateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.with_streaming_response.create( @@ -276,7 +278,7 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.retrieve( @@ -284,7 +286,7 @@ async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(WorkerRetrieveResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.with_raw_response.retrieve( @@ -296,7 +298,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None worker = await response.parse() assert_matches_type(WorkerRetrieveResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.with_streaming_response.retrieve( @@ -310,7 +312,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) - assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): @@ -318,7 +320,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.update( @@ -326,7 +328,7 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.update( @@ -336,7 +338,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) ) assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.with_raw_response.update( @@ -348,7 +350,7 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: worker = await response.parse() assert_matches_type(WorkerUpdateResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.with_streaming_response.update( @@ -362,7 +364,7 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): @@ -370,13 +372,13 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: id="", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_list(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.list() assert_matches_type(WorkerListResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.with_raw_response.list() @@ -386,7 +388,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: worker = await response.parse() assert_matches_type(WorkerListResponse, worker, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.with_streaming_response.list() as response: @@ -398,7 +400,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_delete(self, async_client: AsyncBrainbase) -> None: worker = await async_client.workers.delete( @@ -406,7 +408,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None: ) assert worker is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.with_raw_response.delete( @@ -418,7 +420,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: worker = await response.parse() assert worker is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.with_streaming_response.delete( @@ -432,7 +434,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): diff --git a/tests/api_resources/workers/deployments/test_voice.py b/tests/api_resources/workers/deployments/test_voice.py index 7bdabc59..f8b0b122 100644 --- a/tests/api_resources/workers/deployments/test_voice.py +++ b/tests/api_resources/workers/deployments/test_voice.py @@ -10,10 +10,8 @@ from brainbase import Brainbase, AsyncBrainbase from tests.utils import assert_matches_type from brainbase.types.workers.deployments import ( + VoiceDeployment, VoiceListResponse, - VoiceCreateResponse, - VoiceUpdateResponse, - VoiceRetrieveResponse, ) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -22,16 +20,16 @@ class TestVoice: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.create( worker_id="workerId", name="name", ) - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create_with_all_params(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.create( @@ -41,9 +39,9 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None: voice_id="voiceId", voice_provider="voiceProvider", ) - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_create(self, client: Brainbase) -> None: response = client.workers.deployments.voice.with_raw_response.create( @@ -54,9 +52,9 @@ def test_raw_response_create(self, client: Brainbase) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_create(self, client: Brainbase) -> None: with client.workers.deployments.voice.with_streaming_response.create( @@ -67,11 +65,11 @@ def test_streaming_response_create(self, client: Brainbase) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_create(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -80,16 +78,16 @@ def test_path_params_create(self, client: Brainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_retrieve(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.retrieve( deployment_id="deploymentId", worker_id="workerId", ) - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_retrieve(self, client: Brainbase) -> None: response = client.workers.deployments.voice.with_raw_response.retrieve( @@ -100,9 +98,9 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_retrieve(self, client: Brainbase) -> None: with client.workers.deployments.voice.with_streaming_response.retrieve( @@ -113,11 +111,11 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_retrieve(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -132,7 +130,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.update( @@ -140,9 +138,9 @@ def test_method_update(self, client: Brainbase) -> None: worker_id="workerId", name="name", ) - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update_with_all_params(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.update( @@ -153,9 +151,9 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None: voice_id="voiceId", voice_provider="voiceProvider", ) - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_update(self, client: Brainbase) -> None: response = client.workers.deployments.voice.with_raw_response.update( @@ -167,9 +165,9 @@ def test_raw_response_update(self, client: Brainbase) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_update(self, client: Brainbase) -> None: with client.workers.deployments.voice.with_streaming_response.update( @@ -181,11 +179,11 @@ def test_streaming_response_update(self, client: Brainbase) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = response.parse() - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_update(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -202,7 +200,7 @@ def test_path_params_update(self, client: Brainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_list(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.list( @@ -210,7 +208,7 @@ def test_method_list(self, client: Brainbase) -> None: ) assert_matches_type(VoiceListResponse, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_list(self, client: Brainbase) -> None: response = client.workers.deployments.voice.with_raw_response.list( @@ -222,7 +220,7 @@ def test_raw_response_list(self, client: Brainbase) -> None: voice = response.parse() assert_matches_type(VoiceListResponse, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_list(self, client: Brainbase) -> None: with client.workers.deployments.voice.with_streaming_response.list( @@ -236,7 +234,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_list(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -244,7 +242,7 @@ def test_path_params_list(self, client: Brainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_delete(self, client: Brainbase) -> None: voice = client.workers.deployments.voice.delete( @@ -253,7 +251,7 @@ def test_method_delete(self, client: Brainbase) -> None: ) assert voice is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_delete(self, client: Brainbase) -> None: response = client.workers.deployments.voice.with_raw_response.delete( @@ -266,7 +264,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None: voice = response.parse() assert voice is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_delete(self, client: Brainbase) -> None: with client.workers.deployments.voice.with_streaming_response.delete( @@ -281,7 +279,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_delete(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -298,18 +296,20 @@ def test_path_params_delete(self, client: Brainbase) -> None: class TestAsyncVoice: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.create( worker_id="workerId", name="name", ) - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.create( @@ -319,9 +319,9 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) voice_id="voiceId", voice_provider="voiceProvider", ) - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.deployments.voice.with_raw_response.create( @@ -332,9 +332,9 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.deployments.voice.with_streaming_response.create( @@ -345,11 +345,11 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceCreateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_create(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -358,16 +358,16 @@ async def test_path_params_create(self, async_client: AsyncBrainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.retrieve( deployment_id="deploymentId", worker_id="workerId", ) - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.deployments.voice.with_raw_response.retrieve( @@ -378,9 +378,9 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.deployments.voice.with_streaming_response.retrieve( @@ -391,11 +391,11 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) - assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceRetrieveResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -410,7 +410,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.update( @@ -418,9 +418,9 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None: worker_id="workerId", name="name", ) - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.update( @@ -431,9 +431,9 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) voice_id="voiceId", voice_provider="voiceProvider", ) - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.deployments.voice.with_raw_response.update( @@ -445,9 +445,9 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.deployments.voice.with_streaming_response.update( @@ -459,11 +459,11 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> assert response.http_request.headers.get("X-Stainless-Lang") == "python" voice = await response.parse() - assert_matches_type(VoiceUpdateResponse, voice, path=["response"]) + assert_matches_type(VoiceDeployment, voice, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -480,7 +480,7 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_list(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.list( @@ -488,7 +488,7 @@ async def test_method_list(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(VoiceListResponse, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.deployments.voice.with_raw_response.list( @@ -500,7 +500,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: voice = await response.parse() assert_matches_type(VoiceListResponse, voice, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.deployments.voice.with_streaming_response.list( @@ -514,7 +514,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_list(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -522,7 +522,7 @@ async def test_path_params_list(self, async_client: AsyncBrainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_delete(self, async_client: AsyncBrainbase) -> None: voice = await async_client.workers.deployments.voice.delete( @@ -531,7 +531,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None: ) assert voice is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.deployments.voice.with_raw_response.delete( @@ -544,7 +544,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: voice = await response.parse() assert voice is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.deployments.voice.with_streaming_response.delete( @@ -559,7 +559,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): diff --git a/tests/api_resources/workers/test_flows.py b/tests/api_resources/workers/test_flows.py index e4d645ca..93cc926a 100644 --- a/tests/api_resources/workers/test_flows.py +++ b/tests/api_resources/workers/test_flows.py @@ -22,7 +22,7 @@ class TestFlows: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create(self, client: Brainbase) -> None: flow = client.workers.flows.create( @@ -32,7 +32,7 @@ def test_method_create(self, client: Brainbase) -> None: ) assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_create_with_all_params(self, client: Brainbase) -> None: flow = client.workers.flows.create( @@ -43,7 +43,7 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None: ) assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_create(self, client: Brainbase) -> None: response = client.workers.flows.with_raw_response.create( @@ -57,7 +57,7 @@ def test_raw_response_create(self, client: Brainbase) -> None: flow = response.parse() assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_create(self, client: Brainbase) -> None: with client.workers.flows.with_streaming_response.create( @@ -73,7 +73,7 @@ def test_streaming_response_create(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_create(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -83,7 +83,7 @@ def test_path_params_create(self, client: Brainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_retrieve(self, client: Brainbase) -> None: flow = client.workers.flows.retrieve( @@ -92,7 +92,7 @@ def test_method_retrieve(self, client: Brainbase) -> None: ) assert_matches_type(FlowRetrieveResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_retrieve(self, client: Brainbase) -> None: response = client.workers.flows.with_raw_response.retrieve( @@ -105,7 +105,7 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None: flow = response.parse() assert_matches_type(FlowRetrieveResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_retrieve(self, client: Brainbase) -> None: with client.workers.flows.with_streaming_response.retrieve( @@ -120,7 +120,7 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_retrieve(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -135,7 +135,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update(self, client: Brainbase) -> None: flow = client.workers.flows.update( @@ -144,7 +144,7 @@ def test_method_update(self, client: Brainbase) -> None: ) assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update_with_all_params(self, client: Brainbase) -> None: flow = client.workers.flows.update( @@ -156,7 +156,7 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None: ) assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_update(self, client: Brainbase) -> None: response = client.workers.flows.with_raw_response.update( @@ -169,7 +169,7 @@ def test_raw_response_update(self, client: Brainbase) -> None: flow = response.parse() assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_update(self, client: Brainbase) -> None: with client.workers.flows.with_streaming_response.update( @@ -184,7 +184,7 @@ def test_streaming_response_update(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_update(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -199,7 +199,7 @@ def test_path_params_update(self, client: Brainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_list(self, client: Brainbase) -> None: flow = client.workers.flows.list( @@ -207,7 +207,7 @@ def test_method_list(self, client: Brainbase) -> None: ) assert_matches_type(FlowListResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_list(self, client: Brainbase) -> None: response = client.workers.flows.with_raw_response.list( @@ -219,7 +219,7 @@ def test_raw_response_list(self, client: Brainbase) -> None: flow = response.parse() assert_matches_type(FlowListResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_list(self, client: Brainbase) -> None: with client.workers.flows.with_streaming_response.list( @@ -233,7 +233,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_list(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -241,7 +241,7 @@ def test_path_params_list(self, client: Brainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_delete(self, client: Brainbase) -> None: flow = client.workers.flows.delete( @@ -250,7 +250,7 @@ def test_method_delete(self, client: Brainbase) -> None: ) assert flow is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_delete(self, client: Brainbase) -> None: response = client.workers.flows.with_raw_response.delete( @@ -263,7 +263,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None: flow = response.parse() assert flow is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_delete(self, client: Brainbase) -> None: with client.workers.flows.with_streaming_response.delete( @@ -278,7 +278,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_path_params_delete(self, client: Brainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -295,9 +295,11 @@ def test_path_params_delete(self, client: Brainbase) -> None: class TestAsyncFlows: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.create( @@ -307,7 +309,7 @@ async def test_method_create(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.create( @@ -318,7 +320,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) ) assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.flows.with_raw_response.create( @@ -332,7 +334,7 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None: flow = await response.parse() assert_matches_type(FlowCreateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.flows.with_streaming_response.create( @@ -348,7 +350,7 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_create(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -358,7 +360,7 @@ async def test_path_params_create(self, async_client: AsyncBrainbase) -> None: name="name", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.retrieve( @@ -367,7 +369,7 @@ async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(FlowRetrieveResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.flows.with_raw_response.retrieve( @@ -380,7 +382,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None flow = await response.parse() assert_matches_type(FlowRetrieveResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.flows.with_streaming_response.retrieve( @@ -395,7 +397,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) - assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -410,7 +412,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.update( @@ -419,7 +421,7 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.update( @@ -431,7 +433,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) ) assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.flows.with_raw_response.update( @@ -444,7 +446,7 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None: flow = await response.parse() assert_matches_type(FlowUpdateResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.flows.with_streaming_response.update( @@ -459,7 +461,7 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -474,7 +476,7 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None: worker_id="workerId", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_list(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.list( @@ -482,7 +484,7 @@ async def test_method_list(self, async_client: AsyncBrainbase) -> None: ) assert_matches_type(FlowListResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.flows.with_raw_response.list( @@ -494,7 +496,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None: flow = await response.parse() assert_matches_type(FlowListResponse, flow, path=["response"]) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.flows.with_streaming_response.list( @@ -508,7 +510,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_list(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): @@ -516,7 +518,7 @@ async def test_path_params_list(self, async_client: AsyncBrainbase) -> None: "", ) - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_delete(self, async_client: AsyncBrainbase) -> None: flow = await async_client.workers.flows.delete( @@ -525,7 +527,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None: ) assert flow is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: response = await async_client.workers.flows.with_raw_response.delete( @@ -538,7 +540,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None: flow = await response.parse() assert flow is None - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None: async with async_client.workers.flows.with_streaming_response.delete( @@ -553,7 +555,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> assert cast(Any, response.is_closed) is True - @pytest.mark.skip() + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"): diff --git a/tests/conftest.py b/tests/conftest.py index 8e89d982..7a6d4de1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + from __future__ import annotations import os import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator +import httpx import pytest from pytest_asyncio import is_async_test -from brainbase import Brainbase, AsyncBrainbase +from brainbase import Brainbase, AsyncBrainbase, DefaultAioHttpClient +from brainbase._utils import is_dict if TYPE_CHECKING: - from _pytest.fixtures import FixtureRequest + from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] pytest.register_assert_rewrite("tests.utils") @@ -25,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: for async_test in pytest_asyncio_tests: async_test.add_marker(session_scope_marker, append=False) + # We skip tests that use both the aiohttp client and respx_mock as respx_mock + # doesn't support custom transports. + for item in items: + if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames: + continue + + if not hasattr(item, "callspec"): + continue + + async_client_param = item.callspec.params.get("async_client") + if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp": + item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock")) + base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -43,9 +60,25 @@ def client(request: FixtureRequest) -> Iterator[Brainbase]: @pytest.fixture(scope="session") async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncBrainbase]: - strict = getattr(request, "param", True) - if not isinstance(strict, bool): - raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") - - async with AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: + param = getattr(request, "param", True) + + # defaults + strict = True + http_client: None | httpx.AsyncClient = None + + if isinstance(param, bool): + strict = param + elif is_dict(param): + strict = param.get("strict", True) + assert isinstance(strict, bool) + + http_client_type = param.get("http_client", "httpx") + if http_client_type == "aiohttp": + http_client = DefaultAioHttpClient() + else: + raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict") + + async with AsyncBrainbase( + base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client + ) as client: yield client diff --git a/tests/test_client.py b/tests/test_client.py index c3570edb..6f35f9d0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,13 +6,10 @@ import os import sys import json -import time import asyncio import inspect -import subprocess import tracemalloc from typing import Any, Union, cast -from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -23,13 +20,17 @@ from brainbase import Brainbase, AsyncBrainbase, APIResponseValidationError from brainbase._types import Omit +from brainbase._utils import asyncify from brainbase._models import BaseModel, FinalRequestOptions -from brainbase._constants import RAW_RESPONSE_HEADER from brainbase._exceptions import APIStatusError, BrainbaseError, APITimeoutError, APIResponseValidationError from brainbase._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, + OtherPlatform, + DefaultHttpxClient, + DefaultAsyncHttpxClient, + get_platform, make_request_options, ) @@ -58,51 +59,49 @@ def _get_open_connections(client: Brainbase | AsyncBrainbase) -> int: class TestBrainbase: - client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: Brainbase) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Brainbase) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: Brainbase) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: Brainbase) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = Brainbase( @@ -137,6 +136,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = Brainbase( @@ -174,13 +174,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: Brainbase) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -190,12 +192,13 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" - def test_copy_build_request(self) -> None: + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, client: Brainbase) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -252,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: Brainbase) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -272,6 +273,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -283,6 +286,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = Brainbase( @@ -293,6 +298,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = Brainbase( @@ -303,6 +310,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -314,14 +323,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = Brainbase( + test_client = Brainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = Brainbase( + test_client2 = Brainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -330,10 +339,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -362,8 +374,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -374,7 +388,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -385,7 +399,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -396,8 +410,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -407,7 +421,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -418,8 +432,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -432,7 +446,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -446,7 +460,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -462,7 +476,7 @@ def test_request_extra_query(self) -> None: def test_multipart_repeating_array(self, client: Brainbase) -> None: request = client._build_request( FinalRequestOptions.construct( - method="get", + method="post", url="/foo", headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, @@ -489,7 +503,7 @@ def test_multipart_repeating_array(self, client: Brainbase) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: Brainbase) -> None: class Model1(BaseModel): name: str @@ -498,12 +512,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: Brainbase) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -514,18 +528,18 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Brainbase) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -541,7 +555,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -553,6 +567,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(BRAINBASE_BASE_URL="http://localhost:5000/from/env"): client = Brainbase(api_key=api_key, _strict_response_validation=True) @@ -580,6 +596,7 @@ def test_base_url_trailing_slash(self, client: Brainbase) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -603,6 +620,7 @@ def test_base_url_no_trailing_slash(self, client: Brainbase) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -626,35 +644,36 @@ def test_absolute_request_url(self, client: Brainbase) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: Brainbase) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -674,11 +693,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -701,9 +723,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: Brainbase + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -711,27 +733,22 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Brainbase) -> None: respx_mock.get("/api/workers").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): - self.client.get( - "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}} - ) + client.workers.with_streaming_response.list().__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Brainbase) -> None: respx_mock.get("/api/workers").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): - self.client.get( - "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}} - ) - - assert _get_open_connections(self.client) == 0 + client.workers.with_streaming_response.list().__enter__() + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -810,57 +827,100 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" + def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + + client = DefaultHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) -class TestAsyncBrainbase: - client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects(self, respx_mock: MockRouter, client: Brainbase) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Brainbase) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + with pytest.raises(APIStatusError) as exc_info: + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + +class TestAsyncBrainbase: @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncBrainbase) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncBrainbase) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncBrainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -893,8 +953,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncBrainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -930,13 +991,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncBrainbase) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -946,12 +1009,13 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" - def test_copy_build_request(self) -> None: + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, async_client: AsyncBrainbase) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -1008,12 +1072,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncBrainbase) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1028,6 +1092,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1039,6 +1105,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncBrainbase( @@ -1049,6 +1117,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncBrainbase( @@ -1059,6 +1129,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1069,15 +1141,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncBrainbase( + async def test_default_headers_option(self) -> None: + test_client = AsyncBrainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncBrainbase( + test_client2 = AsyncBrainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1086,10 +1158,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + await test_client2.close() + def test_validate_headers(self) -> None: client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1100,7 +1175,7 @@ def test_validate_headers(self) -> None: client2 = AsyncBrainbase(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncBrainbase( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1118,8 +1193,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1130,7 +1207,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1141,7 +1218,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1152,8 +1229,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1163,7 +1240,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1174,8 +1251,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Brainbase) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1188,7 +1265,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1202,7 +1279,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1218,7 +1295,7 @@ def test_request_extra_query(self) -> None: def test_multipart_repeating_array(self, async_client: AsyncBrainbase) -> None: request = async_client._build_request( FinalRequestOptions.construct( - method="get", + method="post", url="/foo", headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, @@ -1245,7 +1322,7 @@ def test_multipart_repeating_array(self, async_client: AsyncBrainbase) -> None: ] @pytest.mark.respx(base_url=base_url) - async def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: class Model1(BaseModel): name: str @@ -1254,12 +1331,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1270,18 +1347,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncBrainbase + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1297,11 +1376,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncBrainbase( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1311,7 +1390,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(BRAINBASE_BASE_URL="http://localhost:5000/from/env"): client = AsyncBrainbase(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1331,7 +1412,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None: + async def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1340,6 +1421,7 @@ def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1356,7 +1438,7 @@ def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None: + async def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1365,6 +1447,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1381,7 +1464,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncBrainbase) -> None: + async def test_absolute_request_url(self, client: AsyncBrainbase) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1390,37 +1473,37 @@ def test_absolute_request_url(self, client: AsyncBrainbase) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1431,7 +1514,6 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1443,11 +1525,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1470,43 +1555,40 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncBrainbase + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + async def test_retrying_timeout_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncBrainbase + ) -> None: respx_mock.get("/api/workers").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): - await self.client.get( - "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}} - ) + await async_client.workers.with_streaming_response.list().__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + async def test_retrying_status_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncBrainbase + ) -> None: respx_mock.get("/api/workers").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): - await self.client.get( - "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}} - ) - - assert _get_open_connections(self.client) == 0 + await async_client.workers.with_streaming_response.list().__aenter__() + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1538,7 +1620,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncBrainbase, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1562,7 +1643,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncBrainbase, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1583,47 +1663,55 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" - def test_get_platform(self) -> None: - # A previous implementation of asyncify could leave threads unterminated when - # used with nest_asyncio. - # - # Since nest_asyncio.apply() is global and cannot be un-applied, this - # test is run in a separate process to avoid affecting other tests. - test_code = dedent(""" - import asyncio - import nest_asyncio - import threading - - from brainbase._utils import asyncify - from brainbase._base_client import get_platform - - async def test_main() -> None: - result = await asyncify(get_platform)() - print(result) - for thread in threading.enumerate(): - print(thread.name) - - nest_asyncio.apply() - asyncio.run(test_main()) - """) - with subprocess.Popen( - [sys.executable, "-c", test_code], - text=True, - ) as process: - timeout = 10 # seconds - - start_time = time.monotonic() - while True: - return_code = process.poll() - if return_code is not None: - if return_code != 0: - raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code") - - # success - break - - if time.monotonic() - start_time > timeout: - process.kill() - raise AssertionError("calling get_platform using asyncify resulted in a hung process") - - time.sleep(0.1) + async def test_get_platform(self) -> None: + platform = await asyncify(get_platform)() + assert isinstance(platform, (str, OtherPlatform)) + + async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + + client = DefaultAsyncHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + async def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultAsyncHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + + with pytest.raises(APIStatusError) as exc_info: + await async_client.post( + "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response + ) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" diff --git a/tests/test_models.py b/tests/test_models.py index 1063c832..0471c6c9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Union, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast from datetime import datetime, timezone from typing_extensions import Literal, Annotated, TypeAliasType @@ -8,8 +8,8 @@ from pydantic import Field from brainbase._utils import PropertyInfo -from brainbase._compat import PYDANTIC_V2, parse_obj, model_dump, model_json -from brainbase._models import BaseModel, construct_type +from brainbase._compat import PYDANTIC_V1, parse_obj, model_dump, model_json +from brainbase._models import DISCRIMINATOR_CACHE, BaseModel, construct_type class BasicModel(BaseModel): @@ -294,12 +294,12 @@ class Model(BaseModel): assert cast(bool, m.foo) is True m = Model.construct(foo={"name": 3}) - if PYDANTIC_V2: - assert isinstance(m.foo, Submodel1) - assert m.foo.name == 3 # type: ignore - else: + if PYDANTIC_V1: assert isinstance(m.foo, Submodel2) assert m.foo.name == "3" + else: + assert isinstance(m.foo, Submodel1) + assert m.foo.name == 3 # type: ignore def test_list_of_unions() -> None: @@ -426,10 +426,10 @@ class Model(BaseModel): expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc) - if PYDANTIC_V2: - expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' - else: + if PYDANTIC_V1: expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}' + else: + expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' model = Model.construct(created_at="2019-12-27T18:11:19.117Z") assert model.created_at == expected @@ -492,12 +492,15 @@ class Model(BaseModel): resource_id: Optional[str] = None m = Model.construct() + assert m.resource_id is None assert "resource_id" not in m.model_fields_set m = Model.construct(resource_id=None) + assert m.resource_id is None assert "resource_id" in m.model_fields_set m = Model.construct(resource_id="foo") + assert m.resource_id == "foo" assert "resource_id" in m.model_fields_set @@ -528,7 +531,7 @@ class Model2(BaseModel): assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} assert m4.to_dict(mode="json") == {"created_at": time_str} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -553,7 +556,7 @@ class Model(BaseModel): assert m3.model_dump() == {"foo": None} assert m3.model_dump(exclude_none=True) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) @@ -577,10 +580,10 @@ class Model(BaseModel): assert json.loads(m.to_json()) == {"FOO": "hello"} assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} - if PYDANTIC_V2: - assert m.to_json(indent=None) == '{"FOO":"hello"}' - else: + if PYDANTIC_V1: assert m.to_json(indent=None) == '{"FOO": "hello"}' + else: + assert m.to_json(indent=None) == '{"FOO":"hello"}' m2 = Model() assert json.loads(m2.to_json()) == {} @@ -592,7 +595,7 @@ class Model(BaseModel): assert json.loads(m3.to_json()) == {"FOO": None} assert json.loads(m3.to_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_json(warnings=False) @@ -619,7 +622,7 @@ class Model(BaseModel): assert json.loads(m3.model_dump_json()) == {"foo": None} assert json.loads(m3.model_dump_json(exclude_none=True)) == {} - if not PYDANTIC_V2: + if PYDANTIC_V1: with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump_json(round_trip=True) @@ -676,12 +679,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_unknown_variant() -> None: @@ -765,12 +768,12 @@ class B(BaseModel): ) assert isinstance(m, A) assert m.foo_type == "a" - if PYDANTIC_V2: - assert m.data == 100 # type: ignore[comparison-overlap] - else: + if PYDANTIC_V1: # pydantic v1 automatically converts inputs to strings # if the expected type is a str assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: @@ -806,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert not DISCRIMINATOR_CACHE.get(UnionType) m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -815,7 +818,7 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = DISCRIMINATOR_CACHE.get(UnionType) assert discriminator is not None m = construct_type( @@ -827,12 +830,12 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_type_alias_type() -> None: - Alias = TypeAliasType("Alias", str) + Alias = TypeAliasType("Alias", str) # pyright: ignore class Model(BaseModel): alias: Alias @@ -846,7 +849,7 @@ class Model(BaseModel): assert m.union == "bar" -@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") def test_field_named_cls() -> None: class Model(BaseModel): cls: str @@ -854,3 +857,107 @@ class Model(BaseModel): m = construct_type(value={"cls": "foo"}, type_=Model) assert isinstance(m, Model) assert isinstance(m.cls, str) + + +def test_discriminated_union_case() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["b"] + + data: List[Union[A, object]] + + class ModelA(BaseModel): + type: Literal["modelA"] + + data: int + + class ModelB(BaseModel): + type: Literal["modelB"] + + required: str + + data: Union[A, B] + + # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required` + m = construct_type( + value={"type": "modelB", "data": {"type": "a", "data": True}}, + type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]), + ) + + assert isinstance(m, ModelB) + + +def test_nested_discriminated_union() -> None: + class InnerType1(BaseModel): + type: Literal["type_1"] + + class InnerModel(BaseModel): + inner_value: str + + class InnerType2(BaseModel): + type: Literal["type_2"] + some_inner_model: InnerModel + + class Type1(BaseModel): + base_type: Literal["base_type_1"] + value: Annotated[ + Union[ + InnerType1, + InnerType2, + ], + PropertyInfo(discriminator="type"), + ] + + class Type2(BaseModel): + base_type: Literal["base_type_2"] + + T = Annotated[ + Union[ + Type1, + Type2, + ], + PropertyInfo(discriminator="base_type"), + ] + + model = construct_type( + type_=T, + value={ + "base_type": "base_type_1", + "value": { + "type": "type_2", + }, + }, + ) + assert isinstance(model, Type1) + assert isinstance(model.value, InnerType2) + + +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now") +def test_extra_properties() -> None: + class Item(BaseModel): + prop: int + + class Model(BaseModel): + __pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + other: str + + if TYPE_CHECKING: + + def __getattr__(self, attr: str) -> Item: ... + + model = construct_type( + type_=Model, + value={ + "a": {"prop": 1}, + "other": "foo", + }, + ) + assert isinstance(model, Model) + assert model.a.prop == 1 + assert isinstance(model.a, Item) + assert model.other == "foo" diff --git a/tests/test_transform.py b/tests/test_transform.py index 2293b288..9023ce5d 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -2,20 +2,20 @@ import io import pathlib -from typing import Any, List, Union, TypeVar, Iterable, Optional, cast +from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict import pytest -from brainbase._types import Base64FileInput +from brainbase._types import Base64FileInput, omit, not_given from brainbase._utils import ( PropertyInfo, transform as _transform, parse_datetime, async_transform as _async_transform, ) -from brainbase._compat import PYDANTIC_V2 +from brainbase._compat import PYDANTIC_V1 from brainbase._models import BaseModel _T = TypeVar("_T") @@ -189,7 +189,7 @@ class DateModel(BaseModel): @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - tz = "Z" if PYDANTIC_V2 else "+00:00" + tz = "+00:00" if PYDANTIC_V1 else "Z" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] @@ -297,11 +297,11 @@ async def test_pydantic_unknown_field(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_types(use_async: bool) -> None: model = MyModel.construct(foo=True) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": True} @@ -309,11 +309,11 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None: @pytest.mark.asyncio async def test_pydantic_mismatched_object_type(use_async: bool) -> None: model = MyModel.construct(foo=MyModel.construct(hello="world")) - if PYDANTIC_V2: + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: with pytest.warns(UserWarning): params = await transform(model, Any, use_async) - else: - params = await transform(model, Any, use_async) assert cast(Any, params) == {"foo": {"hello": "world"}} @@ -388,6 +388,15 @@ def my_iter() -> Iterable[Baz8]: } +@parametrize +@pytest.mark.asyncio +async def test_dictionary_items(use_async: bool) -> None: + class DictItems(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}} + + class TypedDictIterableUnionStr(TypedDict): foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] @@ -423,3 +432,29 @@ async def test_base64_file_input(use_async: bool) -> None: assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { "foo": "SGVsbG8sIHdvcmxkIQ==" } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_transform_skipping(use_async: bool) -> None: + # lists of ints are left as-is + data = [1, 2, 3] + assert await transform(data, List[int], use_async) is data + + # iterables of ints are converted to a list + data = iter([1, 2, 3]) + assert await transform(data, Iterable[int], use_async) == [1, 2, 3] + + +@parametrize +@pytest.mark.asyncio +async def test_strips_notgiven(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {} diff --git a/tests/test_utils/test_datetime_parse.py b/tests/test_utils/test_datetime_parse.py new file mode 100644 index 00000000..be6c09ec --- /dev/null +++ b/tests/test_utils/test_datetime_parse.py @@ -0,0 +1,110 @@ +""" +Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py +with modifications so it works without pydantic v1 imports. +""" + +from typing import Type, Union +from datetime import date, datetime, timezone, timedelta + +import pytest + +from brainbase._utils import parse_date, parse_datetime + + +def create_tz(minutes: int) -> timezone: + return timezone(timedelta(minutes=minutes)) + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + ("1494012444.883309", date(2017, 5, 5)), + (b"1494012444.883309", date(2017, 5, 5)), + (1_494_012_444.883_309, date(2017, 5, 5)), + ("1494012444", date(2017, 5, 5)), + (1_494_012_444, date(2017, 5, 5)), + (0, date(1970, 1, 1)), + ("2012-04-23", date(2012, 4, 23)), + (b"2012-04-23", date(2012, 4, 23)), + ("2012-4-9", date(2012, 4, 9)), + (date(2012, 4, 9), date(2012, 4, 9)), + (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)), + # Invalid inputs + ("x20120423", ValueError), + ("2012-04-56", ValueError), + (19_999_999_999, date(2603, 10, 11)), # just before watershed + (20_000_000_001, date(1970, 8, 20)), # just after watershed + (1_549_316_052, date(2019, 2, 4)), # nowish in s + (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms + (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs + (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns + ("infinity", date(9999, 12, 31)), + ("inf", date(9999, 12, 31)), + (float("inf"), date(9999, 12, 31)), + ("infinity ", date(9999, 12, 31)), + (int("1" + "0" * 100), date(9999, 12, 31)), + (1e1000, date(9999, 12, 31)), + ("-infinity", date(1, 1, 1)), + ("-inf", date(1, 1, 1)), + ("nan", ValueError), + ], +) +def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_date(value) + else: + assert parse_date(value) == result + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + # values in seconds + ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + # values in ms + ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)), + ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)), + (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)), + ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)), + ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)), + ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))), + ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))), + ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))), + ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (datetime(2017, 5, 5), datetime(2017, 5, 5)), + (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), + # Invalid inputs + ("x20120423091500", ValueError), + ("2012-04-56T09:15:90", ValueError), + ("2012-04-23T11:05:00-25:00", ValueError), + (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed + (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed + (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s + (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms + (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in μs + (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns + ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)), + (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)), + (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("-infinity", datetime(1, 1, 1, 0, 0)), + ("-inf", datetime(1, 1, 1, 0, 0)), + ("nan", ValueError), + ], +) +def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_datetime(value) + else: + assert parse_datetime(value) == result diff --git a/tests/test_utils/test_proxy.py b/tests/test_utils/test_proxy.py index d79d1bb5..92ae06db 100644 --- a/tests/test_utils/test_proxy.py +++ b/tests/test_utils/test_proxy.py @@ -21,3 +21,14 @@ def test_recursive_proxy() -> None: assert dir(proxy) == [] assert type(proxy).__name__ == "RecursiveLazyProxy" assert type(operator.attrgetter("name.foo.bar.baz")(proxy)).__name__ == "RecursiveLazyProxy" + + +def test_isinstance_does_not_error() -> None: + class AlwaysErrorProxy(LazyProxy[Any]): + @override + def __load__(self) -> Any: + raise RuntimeError("Mocking missing dependency") + + proxy = AlwaysErrorProxy() + assert not isinstance(proxy, dict) + assert isinstance(proxy, LazyProxy) diff --git a/tests/utils.py b/tests/utils.py index fa696b4a..5dbc46dc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import inspect import traceback import contextlib -from typing import Any, TypeVar, Iterator, cast +from typing import Any, TypeVar, Iterator, Sequence, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -15,10 +15,11 @@ is_list_type, is_union_type, extract_type_arg, + is_sequence_type, is_annotated_type, is_type_alias_type, ) -from brainbase._compat import PYDANTIC_V2, field_outer_type, get_model_fields +from brainbase._compat import PYDANTIC_V1, field_outer_type, get_model_fields from brainbase._models import BaseModel BaseModelT = TypeVar("BaseModelT", bound=BaseModel) @@ -27,12 +28,12 @@ def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: for name, field in get_model_fields(model).items(): field_value = getattr(value, name) - if PYDANTIC_V2: - allow_none = False - else: + if PYDANTIC_V1: # in v1 nullability was structured differently # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields allow_none = getattr(field, "allow_none", False) + else: + allow_none = False assert_matches_type( field_outer_type(field), @@ -71,6 +72,13 @@ def assert_matches_type( if is_list_type(type_): return _assert_list_type(type_, value) + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + if origin == str: assert isinstance(value, str) elif origin == int: