diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index ac9a2e75..ff261bad 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -3,7 +3,7 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT}
USER vscode
-RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.35.0" RYE_INSTALL_OPTION="--yes" bash
+RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.44.0" RYE_INSTALL_OPTION="--yes" bash
ENV PATH=/home/vscode/.rye/shims:$PATH
-RUN echo "[[ -d .venv ]] && source .venv/bin/activate" >> /home/vscode/.bashrc
+RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index bbeb30b1..c17fdc16 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -24,6 +24,9 @@
}
}
}
+ },
+ "features": {
+ "ghcr.io/devcontainers/features/node:1": {}
}
// Features to add to the dev container. More info: https://containers.dev/features.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c8a8a4f7..5eaaecd9 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,18 +1,23 @@
name: CI
on:
push:
- branches:
- - main
+ branches-ignore:
+ - 'generated'
+ - 'codegen/**'
+ - 'integrated/**'
+ - 'stl-preview-head/**'
+ - 'stl-preview-base/**'
pull_request:
- branches:
- - main
- - next
+ branches-ignore:
+ - 'stl-preview-head/**'
+ - 'stl-preview-base/**'
jobs:
lint:
+ timeout-minutes: 10
name: lint
- runs-on: ubuntu-latest
-
+ runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }}
+ if: github.event_name == 'push' || github.event.pull_request.head.repo.fork
steps:
- uses: actions/checkout@v4
@@ -21,7 +26,7 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
- RYE_VERSION: '0.35.0'
+ RYE_VERSION: '0.44.0'
RYE_INSTALL_OPTION: '--yes'
- name: Install dependencies
@@ -30,10 +35,51 @@ jobs:
- name: Run lints
run: ./scripts/lint
+ build:
+ if: github.event_name == 'push' || github.event.pull_request.head.repo.fork
+ timeout-minutes: 10
+ name: build
+ permissions:
+ contents: read
+ id-token: write
+ runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }}
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Install Rye
+ run: |
+ curl -sSf https://rye.astral.sh/get | bash
+ echo "$HOME/.rye/shims" >> $GITHUB_PATH
+ env:
+ RYE_VERSION: '0.44.0'
+ RYE_INSTALL_OPTION: '--yes'
+
+ - name: Install dependencies
+ run: rye sync --all-features
+
+ - name: Run build
+ run: rye build
+
+ - name: Get GitHub OIDC Token
+ if: github.repository == 'stainless-sdks/brainbase-python'
+ id: github-oidc
+ uses: actions/github-script@v6
+ with:
+ script: core.setOutput('github_token', await core.getIDToken());
+
+ - name: Upload tarball
+ if: github.repository == 'stainless-sdks/brainbase-python'
+ env:
+ URL: https://pkg.stainless.com/s
+ AUTH: ${{ steps.github-oidc.outputs.github_token }}
+ SHA: ${{ github.sha }}
+ run: ./scripts/utils/upload-artifact.sh
+
test:
+ timeout-minutes: 10
name: test
- runs-on: ubuntu-latest
-
+ runs-on: ${{ github.repository == 'stainless-sdks/brainbase-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }}
+ if: github.event_name == 'push' || github.event.pull_request.head.repo.fork
steps:
- uses: actions/checkout@v4
@@ -42,7 +88,7 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
- RYE_VERSION: '0.35.0'
+ RYE_VERSION: '0.44.0'
RYE_INSTALL_OPTION: '--yes'
- name: Bootstrap
diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml
index 3fbc99f8..ec603561 100644
--- a/.github/workflows/publish-pypi.yml
+++ b/.github/workflows/publish-pypi.yml
@@ -21,7 +21,7 @@ jobs:
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
- RYE_VERSION: '0.35.0'
+ RYE_VERSION: '0.44.0'
RYE_INSTALL_OPTION: '--yes'
- name: Publish to PyPI
diff --git a/.gitignore b/.gitignore
index 87797408..95ceb189 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
.prism.log
-.vscode
_dev
__pycache__
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 127ac87b..3b4c2d4b 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "4.0.0"
+ ".": "4.1.0"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index 774ad327..087e4350 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,2 +1,4 @@
configured_endpoints: 15
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/brainbase-egrigokhan%2Fbrainbase-ab4ce60666d2503f2b7028d55b9f75cc42a76a668cda26576e91b851ea650b0b.yml
+openapi_spec_hash: ec07d4f39ed4cb03f93255b680ca2f35
+config_hash: 6f322d88e08375a924b420a5ee9f269c
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..5b010307
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+ "python.analysis.importFormat": "relative",
+}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aab837a5..926a771f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,129 @@
# Changelog
+## 4.1.0 (2026-01-06)
+
+Full Changelog: [v4.0.0...v4.1.0](https://github.com/BrainbaseHQ/brainbase-python-sdk/compare/v4.0.0...v4.1.0)
+
+### Features
+
+* **api:** update via SDK Studio ([#61](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/61)) ([9c0c551](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/9c0c551d114f385be676f42aab844de1309b9406))
+* clean up environment call outs ([696f18b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/696f18b140f81774258c1bf92e36e070fbf65c7e))
+* **client:** add follow_redirects request option ([6eb41c9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6eb41c9ae22aeeae7139e4cdc0bb9e0f34bb37ff))
+* **client:** add support for aiohttp ([6f6ddd9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6f6ddd94c735b6bcf82920259c45f2a112f12507))
+* **client:** allow passing `NotGiven` for body ([#67](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/67)) ([3ad7f25](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3ad7f25e58b4823792d6475c7c14a3700273dd13))
+* **client:** send `X-Stainless-Read-Timeout` header ([#63](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/63)) ([a594c75](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/a594c7501a919a038a7b4d08754ee837acd45b89))
+* **client:** support file upload requests ([fde965a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/fde965a36bf905c2e8ee5e29013141fca52d6e92))
+* improve future compat with pydantic v3 ([bccbddf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bccbddf086f9791bcd117701a886c16306c4c4f4))
+* **types:** replace List[str] with SequenceNotStr in params ([0578887](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0578887c11b95955522042771b4680d5a9f4f6b9))
+
+
+### Bug Fixes
+
+* asyncify on non-asyncio runtimes ([#66](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/66)) ([ca310cd](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ca310cd8beb201d9ad2c66ab13c1e0ed605e6a91))
+* avoid newer type syntax ([db10820](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/db108205529a83cadc51c9506fc2c7a8c372d964))
+* **ci:** correct conditional ([66eb0ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/66eb0ef07df1d265cbb31619ec71f2e097c42f4f))
+* **ci:** ensure pip is always available ([#78](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/78)) ([d3d295a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d3d295a6a4007e2504000988e6b997a5a96be28b))
+* **ci:** release-doctor — report correct token name ([65cdccf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/65cdccfa4575fb71d6c4947ba13e210537ec181e))
+* **ci:** remove publishing patch ([#79](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/79)) ([493f504](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/493f504df7150a24d8373065b217c8b69320b432))
+* **client:** close streams without requiring full consumption ([e783145](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/e783145b013ec8af7142e357f9ec4d9d7ddc99b1))
+* **client:** correctly parse binary response | stream ([3924997](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/392499744e9c7f7b6d770f243ae700d772663fad))
+* **client:** don't send Content-Type header on GET requests ([6bcbc4e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6bcbc4e13f22a42e2a69d7f62c8148430bbd88d9))
+* **client:** mark some request bodies as optional ([3ad7f25](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3ad7f25e58b4823792d6475c7c14a3700273dd13))
+* compat with Python 3.14 ([f7d6103](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f7d6103a91fad3aa7aba7eb7174afe1b2cfea3b9))
+* **compat:** update signatures of `model_dump` and `model_dump_json` for Pydantic v1 ([898ca7b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/898ca7bd4cf577eeeab764fbcdd9a77518b44587))
+* ensure streams are always closed ([f89af68](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f89af6811daed337f503b3c3ad4cb93ab8b38936))
+* **package:** support direct resource imports ([ad2d130](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ad2d130ead5da49f75340d1c642a00854f9d29b6))
+* **parsing:** correctly handle nested discriminated unions ([cd511d2](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/cd511d2d51ded4d55af012cdc475c2aa79be5785))
+* **parsing:** ignore empty metadata ([bdd8ead](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bdd8eadd20c5e4f34b85c3b8e922831f5f1074ab))
+* **parsing:** parse extra field types ([470d8a8](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/470d8a8851db80f9c3720320a5f51ef214504c52))
+* **perf:** optimize some hot paths ([7cc4937](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7cc4937c3882b956356579443e99826de72a4025))
+* **perf:** skip traversing types for NotGiven values ([38509ba](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/38509baa0c8b7fe6cd68ee34fe769ae1c0d95a5c))
+* **pydantic v1:** more robust ModelField.annotation check ([3dc3480](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3dc3480e627d3199117ef6a2b869d413f6408f7b))
+* **tests:** fix: tests which call HTTP endpoints directly with the example parameters ([539215f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/539215ff08dfbe906fad296b7a09fbd3a2bdb5bf))
+* **types:** allow pyright to infer TypedDict types within SequenceNotStr ([84b4806](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/84b4806cef26c829d8f2cea180d7db87be025788))
+* **types:** handle more discriminated union shapes ([#77](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/77)) ([8b6dcf0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8b6dcf01f11bed61a53a66b45ab3ae255deb82e8))
+* use async_to_httpx_files in patch method ([6bfd0c0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6bfd0c0e4e37cf063c32a82525914a4e6da16546))
+
+
+### Chores
+
+* add Python 3.14 classifier and testing ([c172e9c](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/c172e9c16b4ca6fd3c03631a9f58c85c6132de7d))
+* broadly detect json family of content-type headers ([febefbc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/febefbc87ca1d2bbfadc9932400b5ae42644f49a))
+* bump `httpx-aiohttp` version to 0.1.9 ([7aeb4c8](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7aeb4c87f034ee7f506a64adf52a349efb343cf5))
+* **ci:** add timeout thresholds for CI jobs ([d5cbcd0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d5cbcd06c3d27b71601e34e9ea8fb2a5461ec37d))
+* **ci:** change upload type ([d7e4405](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d7e4405ad6aa676d5f21fe603b705f2bf7d36496))
+* **ci:** enable for pull requests ([1ea6fbc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1ea6fbc396ca2c40b1233e7e07ab81b6f1b19620))
+* **ci:** fix installation instructions ([6291f4a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6291f4ad4b3f4ae3624f3ef3dafc3ef5c9ae4ced))
+* **ci:** only run for pushes and fork pull requests ([5d00f3e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5d00f3ed62f35deb17229244613a788ad45bbdc7))
+* **ci:** only use depot for staging repos ([19ee773](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/19ee77377796949bd043a65b8e6934a60a9aa455))
+* **ci:** upload sdks to package manager ([598ec7e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/598ec7e8bf03284b8db9e94049d5e42a1adef26b))
+* **client:** minor internal fixes ([1e29d3b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1e29d3b13de115b6047b3a712ceeac69f77ba51f))
+* **deps:** mypy 1.18.1 has a regression, pin to 1.17 ([fd940b6](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/fd940b62669bcb40a24c2be17e5ab75d116c4ce2))
+* do not install brew dependencies in ./scripts/bootstrap by default ([67c48f0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/67c48f09376cd85fbb5dca63dc47d83c106e7be6))
+* **docs:** grammar improvements ([540e711](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/540e711335d2757eb6c76138bb5cc8fceae04ff3))
+* **docs:** remove reference to rye shell ([ec32daa](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ec32daa0c329830982eabdffde4356b3dafdfb41))
+* **docs:** update client docstring ([#71](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/71)) ([b41543a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/b41543a89e16e504054f95ef8987e12bef2da3a2))
+* **docs:** use environment variables for authentication in code snippets ([ff7f0ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ff7f0ef0716e3bbe9973b18a4a69b5608376f701))
+* fix typos ([#80](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/80)) ([c1576cc](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/c1576ccb8e5592b395e435cc949ba7c077211b67))
+* **internal/tests:** avoid race condition with implicit client cleanup ([d3a5435](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d3a5435f4e7c9609de5141ec44639b109758a49b))
+* **internal:** add `--fix` argument to lint script ([0b1e68e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0b1e68e074af9dc9a8e60a539c964433c66c3887))
+* **internal:** add missing files argument to base client ([5547d1b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5547d1ba450d6ccb13963d24699f89f487e7dc78))
+* **internal:** add Sequence related utils ([5f815ef](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5f815ef63af5a60d77c4f792f3aaf42b317be58f))
+* **internal:** avoid errors for isinstance checks on proxies ([5dc0949](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5dc09490c13d71ac41d430d78e9c32f8f59c330e))
+* **internal:** base client updates ([0ac179a](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0ac179a512ed4773d3717a4226c7f041c67eab17))
+* **internal:** bump pinned h11 dep ([6cafe07](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6cafe072603e69522dd754cf2d3f2a59b8f49271))
+* **internal:** bump pyright version ([81c2baf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/81c2bafbaab324b6e4c85c1e193148fbe2c88525))
+* **internal:** bump rye to 0.44.0 ([#76](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/76)) ([21a20b3](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/21a20b383e1ce4b072bdbcdefc5774f9e2ba21f4))
+* **internal:** change ci workflow machines ([656643d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/656643d1bd5e3bba6a6997ac3ecbce69d0aaf4cc))
+* **internal:** codegen related update ([0b4efb7](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/0b4efb79489847366f43e83b5be51c68eddac2bd))
+* **internal:** codegen related update ([2f8fbc4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/2f8fbc4b1ebdcf56188bb58f763e2fb9fd8c202b))
+* **internal:** codegen related update ([d6d5a1d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d6d5a1d5cf3fec7ee383f25b8927d773a20230ca))
+* **internal:** codegen related update ([a145cee](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/a145cee50fe2c30203c1beea85b9e87e58339103))
+* **internal:** codegen related update ([#75](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/75)) ([db19786](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/db197864c745b7235e546f349e0a24a7c8cd9801))
+* **internal:** detect missing future annotations with ruff ([23b94ed](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/23b94edc4b2f2e1ab1b8ed02881296a185b3ccac))
+* **internal:** expand CI branch coverage ([7fd1145](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7fd1145852e5f82cd271dcb4432e2474e1cbd9d4))
+* **internal:** fix devcontainers setup ([#68](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/68)) ([97b7254](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/97b725436eaca395261fa04119202ccd938e2edd))
+* **internal:** fix list file params ([1b5e333](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1b5e333c271ee804b2fe7738dfb2ee7bd0044c9a))
+* **internal:** fix ruff target version ([1437b86](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1437b86f88760d95f63a3cd9e5e4e3d1b9cc1673))
+* **internal:** fix type traversing dictionary params ([#64](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/64)) ([1322c80](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/1322c808afea9e8d17ac958e289850115c8d0fe8))
+* **internal:** grammar fix (it's -> its) ([8b1cdb7](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8b1cdb737ecd09456400b79a03d1aa2d02ab8e2b))
+* **internal:** import reformatting ([8a3f6f0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/8a3f6f0fc98525d48998f38b6bafc6b78bbea73d))
+* **internal:** minor type handling changes ([#65](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/65)) ([7e69125](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/7e691251edd799b8ffd067792c0833f99a6906bb))
+* **internal:** move mypy configurations to `pyproject.toml` file ([f87b268](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f87b2684a44a8c6112fa30fd935e42d9d532fdbf))
+* **internal:** properly set __pydantic_private__ ([#69](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/69)) ([bc25b84](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/bc25b84c059996c26d435aa61da24f684b25f2e8))
+* **internal:** reduce CI branch coverage ([2492996](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/249299619c40b6abeef1332f3c3e33e436fe5da5))
+* **internal:** refactor retries to not use recursion ([055e329](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/055e329cdf3927632ab1e4b128c2ed5ce0333e8f))
+* **internal:** remove extra empty newlines ([#74](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/74)) ([3d90dff](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3d90dff0832619c6495bbf7f82cc74ac9aac46c9))
+* **internal:** remove trailing character ([#81](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/81)) ([4cfa80b](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/4cfa80bf1f6a9f24ae0a7244b3ae5a248e131edc))
+* **internal:** remove unused http client options forwarding ([#72](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/72)) ([69a44e3](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/69a44e38bcfb88f523e036d4f434942a96397061))
+* **internal:** slight transform perf improvement ([#82](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/82)) ([5498eaf](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5498eaf9154b388668be1a516d3597d27b74cc7c))
+* **internal:** update comment in script ([103820e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/103820ebecd7a1dd118282d8becf5db502a1ec0d))
+* **internal:** update conftest.py ([83531b4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/83531b46d24f38f1f7c6a3ca13ddbb4e5d16590b))
+* **internal:** update models test ([421a2b5](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/421a2b57d0973344ae3d86ec0f989ca655490970))
+* **internal:** update pydantic dependency ([5dd09a4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5dd09a48ae3a7ef8efe86304ea1d84acd0bf7d39))
+* **internal:** update pyright exclude list ([f306088](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f3060880c80762252f85cb071eae6b7c1263265f))
+* **internal:** update pyright settings ([2d267e1](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/2d267e1d9a8fee34b0dde0ca9274afac975a1648))
+* **package:** drop Python 3.8 support ([95ca18d](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/95ca18d4426eac5714e12b54f5f0fb181aaa85e6))
+* **package:** mark python 3.13 as supported ([06ad64f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/06ad64facff9ca27b1049dc9bc2a41aa1872b7ee))
+* **project:** add settings file for vscode ([6754b39](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/6754b3937cfa144183e8d4633efdd19e2d0e1413))
+* **readme:** fix version rendering on pypi ([b1e9e51](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/b1e9e51bb3adf1c0f2387b4aff1804ff6e40e346))
+* **readme:** update badges ([f3f214f](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/f3f214f315d721421f383ba4b5e63ddd828c0b59))
+* speedup initial import ([3b3f4e6](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3b3f4e63e4488509e60062a71cf5f3f569b4e529))
+* **tests:** add tests for httpx client instantiation & proxies ([5e172cd](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/5e172cd1be541e04ab62ef1c770b5d1e9b7fc44d))
+* **tests:** run tests in parallel ([497b381](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/497b3816f4167e344dbbad4a2910236436641777))
+* **tests:** simplify `get_platform` test ([ef07d85](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/ef07d85cf3d9001104e57e9e1a266aa65f6852e5))
+* **tests:** skip some failing tests on the latest python versions ([39d037c](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/39d037c9713c26103a549a3b0c03003a93ef3fc4))
+* **types:** change optional parameter type from NotGiven to Omit ([14cb9a9](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/14cb9a994e36c901e9b73ca1004a40cb30f87786))
+* update @stainless-api/prism-cli to v5.15.0 ([d766a01](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d766a015e4fd9a56336bf37be4cce0eb08361a20))
+* update github action ([d0f9d9e](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/d0f9d9e2bd470f7ee9f2c29cf514d9a9da093c0b))
+* update lockfile ([77726a0](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/77726a0c2aeb678ef7995da7be5aa667388917ca))
+
+
+### Documentation
+
+* **client:** fix httpx.Timeout documentation reference ([3341b85](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/3341b85797cfcc35e3e8c1f3a151176c85587ec8))
+* update URLs from stainlessapi.com to stainless.com ([#70](https://github.com/BrainbaseHQ/brainbase-python-sdk/issues/70)) ([08062c4](https://github.com/BrainbaseHQ/brainbase-python-sdk/commit/08062c48906749ce529d5ea1fcdfed3b1b18412c))
+
## 4.0.0 (2025-02-04)
Full Changelog: [v3.0.0...v4.0.0](https://github.com/BrainbaseHQ/brainbase-python-sdk/compare/v3.0.0...v4.0.0)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2cf8f44a..37d38185 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -17,8 +17,7 @@ $ rye sync --all-features
You can then run scripts using `rye run python script.py` or by activating the virtual environment:
```sh
-$ rye shell
-# or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work
+# Activate the virtual environment - https://docs.python.org/3/library/venv.html#how-venvs-work
$ source .venv/bin/activate
# now you can omit the `rye run` prefix
diff --git a/LICENSE b/LICENSE
index 62446133..2ab165fa 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright 2025 Brainbase
+ Copyright 2026 Brainbase
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index 33ff2fa4..33243d9c 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,17 @@
# Brainbase Python API library
-[](https://pypi.org/project/brainbase-labs/)
+
+[)](https://pypi.org/project/brainbase-labs/)
-The Brainbase Python library provides convenient access to the Brainbase REST API from any Python 3.8+
+The Brainbase Python library provides convenient access to the Brainbase REST API from any Python 3.9+
application. The library includes type definitions for all request params and response fields,
and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx).
-It is generated with [Stainless](https://www.stainlessapi.com/).
+It is generated with [Stainless](https://www.stainless.com/).
## Documentation
-The REST API documentation can be found on [docs.usebrainbase.xyz](https://docs.usebrainbase.xyz). The full API of this library can be found in [api.md](api.md).
+The REST API documentation can be found on [docs.usebrainbase.com](https://docs.usebrainbase.com). The full API of this library can be found in [api.md](api.md).
## Installation
@@ -62,6 +63,37 @@ asyncio.run(main())
Functionality between the synchronous and asynchronous clients is otherwise identical.
+### With aiohttp
+
+By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend.
+
+You can enable this by installing `aiohttp`:
+
+```sh
+# install from PyPI
+pip install brainbase-labs[aiohttp]
+```
+
+Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`:
+
+```python
+import os
+import asyncio
+from brainbase import DefaultAioHttpClient
+from brainbase import AsyncBrainbase
+
+
+async def main() -> None:
+ async with AsyncBrainbase(
+ api_key=os.environ.get("API_KEY"), # This is the default and can be omitted
+ http_client=DefaultAioHttpClient(),
+ ) as client:
+ workers = await client.workers.list()
+
+
+asyncio.run(main())
+```
+
## Using types
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like:
@@ -136,7 +168,7 @@ client.with_options(max_retries=5).workers.list()
### Timeouts
By default requests time out after 1 minute. You can configure this with a `timeout` option,
-which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object:
+which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object:
```python
from brainbase import Brainbase
@@ -322,7 +354,7 @@ print(brainbase.__version__)
## Requirements
-Python 3.8 or higher.
+Python 3.9 or higher.
## Contributing
diff --git a/SECURITY.md b/SECURITY.md
index 752bd79a..f5c1b88f 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -2,9 +2,9 @@
## Reporting Security Issues
-This SDK is generated by [Stainless Software Inc](http://stainlessapi.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken.
+This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken.
-To report a security issue, please contact the Stainless team at security@stainlessapi.com.
+To report a security issue, please contact the Stainless team at security@stainless.com.
## Responsible Disclosure
@@ -16,11 +16,11 @@ before making any information public.
## Reporting Non-SDK Related Security Issues
If you encounter security issues that are not directly related to SDKs but pertain to the services
-or products provided by Brainbase please follow the respective company's security reporting guidelines.
+or products provided by Brainbase, please follow the respective company's security reporting guidelines.
### Brainbase Terms and Policies
-Please contact dev@brainbase.com for any questions or concerns regarding security of our services.
+Please contact dev-feedback@brainbase.com for any questions or concerns regarding the security of our services.
---
diff --git a/api.md b/api.md
index 2022bf88..ecd4ff83 100644
--- a/api.md
+++ b/api.md
@@ -26,19 +26,14 @@ Methods:
Types:
```python
-from brainbase.types.workers.deployments import (
- VoiceCreateResponse,
- VoiceRetrieveResponse,
- VoiceUpdateResponse,
- VoiceListResponse,
-)
+from brainbase.types.workers.deployments import VoiceDeployment, VoiceListResponse
```
Methods:
-- client.workers.deployments.voice.create(worker_id, \*\*params) -> VoiceCreateResponse
-- client.workers.deployments.voice.retrieve(deployment_id, \*, worker_id) -> VoiceRetrieveResponse
-- client.workers.deployments.voice.update(deployment_id, \*, worker_id, \*\*params) -> VoiceUpdateResponse
+- client.workers.deployments.voice.create(worker_id, \*\*params) -> VoiceDeployment
+- client.workers.deployments.voice.retrieve(deployment_id, \*, worker_id) -> VoiceDeployment
+- client.workers.deployments.voice.update(deployment_id, \*, worker_id, \*\*params) -> VoiceDeployment
- client.workers.deployments.voice.list(worker_id) -> VoiceListResponse
- client.workers.deployments.voice.delete(deployment_id, \*, worker_id) -> None
diff --git a/bin/check-release-environment b/bin/check-release-environment
index 59422a48..b845b0f4 100644
--- a/bin/check-release-environment
+++ b/bin/check-release-environment
@@ -3,7 +3,7 @@
errors=()
if [ -z "${PYPI_TOKEN}" ]; then
- errors+=("The BRAINBASE_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.")
+ errors+=("The PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.")
fi
lenErrors=${#errors[@]}
diff --git a/bin/publish-pypi b/bin/publish-pypi
index 05bfccbb..826054e9 100644
--- a/bin/publish-pypi
+++ b/bin/publish-pypi
@@ -3,7 +3,4 @@
set -eux
mkdir -p dist
rye build --clean
-# Patching importlib-metadata version until upstream library version is updated
-# https://github.com/pypa/twine/issues/977#issuecomment-2189800841
-"$HOME/.rye/self/bin/python3" -m pip install 'importlib-metadata==7.2.1'
rye publish --yes --token=$PYPI_TOKEN
diff --git a/mypy.ini b/mypy.ini
deleted file mode 100644
index e062bd87..00000000
--- a/mypy.ini
+++ /dev/null
@@ -1,50 +0,0 @@
-[mypy]
-pretty = True
-show_error_codes = True
-
-# Exclude _files.py because mypy isn't smart enough to apply
-# the correct type narrowing and as this is an internal module
-# it's fine to just use Pyright.
-#
-# We also exclude our `tests` as mypy doesn't always infer
-# types correctly and Pyright will still catch any type errors.
-exclude = ^(src/brainbase/_files\.py|_dev/.*\.py|tests/.*)$
-
-strict_equality = True
-implicit_reexport = True
-check_untyped_defs = True
-no_implicit_optional = True
-
-warn_return_any = True
-warn_unreachable = True
-warn_unused_configs = True
-
-# Turn these options off as it could cause conflicts
-# with the Pyright options.
-warn_unused_ignores = False
-warn_redundant_casts = False
-
-disallow_any_generics = True
-disallow_untyped_defs = True
-disallow_untyped_calls = True
-disallow_subclassing_any = True
-disallow_incomplete_defs = True
-disallow_untyped_decorators = True
-cache_fine_grained = True
-
-# By default, mypy reports an error if you assign a value to the result
-# of a function call that doesn't return anything. We do this in our test
-# cases:
-# ```
-# result = ...
-# assert result is None
-# ```
-# Changing this codegen to make mypy happy would increase complexity
-# and would not be worth it.
-disable_error_code = func-returns-value,overload-cannot-match
-
-# https://github.com/python/mypy/issues/12162
-[mypy.overrides]
-module = "black.files.*"
-ignore_errors = true
-ignore_missing_imports = true
diff --git a/pyproject.toml b/pyproject.toml
index ac778a61..300483ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,29 +1,32 @@
[project]
name = "brainbase-labs"
-version = "4.0.0"
+version = "4.1.0"
description = "The official Python library for the brainbase API"
dynamic = ["readme"]
license = "Apache-2.0"
authors = [
-{ name = "Brainbase", email = "dev@brainbase.com" },
+{ name = "Brainbase", email = "dev-feedback@brainbase.com" },
]
+
dependencies = [
- "httpx>=0.23.0, <1",
- "pydantic>=1.9.0, <3",
- "typing-extensions>=4.10, <5",
- "anyio>=3.5.0, <5",
- "distro>=1.7.0, <2",
- "sniffio",
+ "httpx>=0.23.0, <1",
+ "pydantic>=1.9.0, <3",
+ "typing-extensions>=4.10, <5",
+ "anyio>=3.5.0, <5",
+ "distro>=1.7.0, <2",
+ "sniffio",
]
-requires-python = ">= 3.8"
+
+requires-python = ">= 3.9"
classifiers = [
"Typing :: Typed",
"Intended Audience :: Developers",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Operating System :: OS Independent",
"Operating System :: POSIX",
"Operating System :: MacOS",
@@ -37,14 +40,15 @@ classifiers = [
Homepage = "https://github.com/BrainbaseHQ/brainbase-python-sdk"
Repository = "https://github.com/BrainbaseHQ/brainbase-python-sdk"
-
+[project.optional-dependencies]
+aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
[tool.rye]
managed = true
# version pins are in requirements-dev.lock
dev-dependencies = [
- "pyright>=1.1.359",
- "mypy",
+ "pyright==1.1.399",
+ "mypy==1.17",
"respx",
"pytest",
"pytest-asyncio",
@@ -54,7 +58,7 @@ dev-dependencies = [
"dirty-equals>=0.6.0",
"importlib-metadata>=6.7.0",
"rich>=13.7.1",
- "nest_asyncio==1.6.0",
+ "pytest-xdist>=3.6.1",
]
[tool.rye.scripts]
@@ -87,7 +91,7 @@ typecheck = { chain = [
"typecheck:mypy" = "mypy ."
[build-system]
-requires = ["hatchling", "hatch-fancy-pypi-readme"]
+requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"]
build-backend = "hatchling.build"
[tool.hatch.build]
@@ -126,7 +130,7 @@ replacement = '[\1](https://github.com/BrainbaseHQ/brainbase-python-sdk/tree/mai
[tool.pytest.ini_options]
testpaths = ["tests"]
-addopts = "--tb=short"
+addopts = "--tb=short -n auto"
xfail_strict = true
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
@@ -139,24 +143,77 @@ filterwarnings = [
# there are a couple of flags that are still disabled by
# default in strict mode as they are experimental and niche.
typeCheckingMode = "strict"
-pythonVersion = "3.8"
+pythonVersion = "3.9"
exclude = [
"_dev",
".venv",
".nox",
+ ".git",
]
reportImplicitOverride = true
+reportOverlappingOverload = false
reportImportCycles = false
reportPrivateUsage = false
+[tool.mypy]
+pretty = true
+show_error_codes = true
+
+# Exclude _files.py because mypy isn't smart enough to apply
+# the correct type narrowing and as this is an internal module
+# it's fine to just use Pyright.
+#
+# We also exclude our `tests` as mypy doesn't always infer
+# types correctly and Pyright will still catch any type errors.
+exclude = ['src/brainbase/_files.py', '_dev/.*.py', 'tests/.*']
+
+strict_equality = true
+implicit_reexport = true
+check_untyped_defs = true
+no_implicit_optional = true
+
+warn_return_any = true
+warn_unreachable = true
+warn_unused_configs = true
+
+# Turn these options off as it could cause conflicts
+# with the Pyright options.
+warn_unused_ignores = false
+warn_redundant_casts = false
+
+disallow_any_generics = true
+disallow_untyped_defs = true
+disallow_untyped_calls = true
+disallow_subclassing_any = true
+disallow_incomplete_defs = true
+disallow_untyped_decorators = true
+cache_fine_grained = true
+
+# By default, mypy reports an error if you assign a value to the result
+# of a function call that doesn't return anything. We do this in our test
+# cases:
+# ```
+# result = ...
+# assert result is None
+# ```
+# Changing this codegen to make mypy happy would increase complexity
+# and would not be worth it.
+disable_error_code = "func-returns-value,overload-cannot-match"
+
+# https://github.com/python/mypy/issues/12162
+[[tool.mypy.overrides]]
+module = "black.files.*"
+ignore_errors = true
+ignore_missing_imports = true
+
[tool.ruff]
line-length = 120
output-format = "grouped"
-target-version = "py37"
+target-version = "py38"
[tool.ruff.format]
docstring-code-format = true
@@ -169,6 +226,8 @@ select = [
"B",
# remove unused imports
"F401",
+ # check for missing future annotations
+ "FA102",
# bare except statements
"E722",
# unused arguments
@@ -191,6 +250,8 @@ unfixable = [
"T203",
]
+extend-safe-fixes = ["FA102"]
+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
"functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 2ccbd89c..8a56d82f 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -7,97 +7,143 @@
# all-features: true
# with-sources: false
# generate-hashes: false
+# universal: false
-e file:.
-annotated-types==0.6.0
+aiohappyeyeballs==2.6.1
+ # via aiohttp
+aiohttp==3.13.2
+ # via brainbase-labs
+ # via httpx-aiohttp
+aiosignal==1.4.0
+ # via aiohttp
+annotated-types==0.7.0
# via pydantic
-anyio==4.4.0
+anyio==4.12.0
# via brainbase-labs
# via httpx
-argcomplete==3.1.2
+argcomplete==3.6.3
+ # via nox
+async-timeout==5.0.1
+ # via aiohttp
+attrs==25.4.0
+ # via aiohttp
# via nox
-certifi==2023.7.22
+backports-asyncio-runner==1.2.0
+ # via pytest-asyncio
+certifi==2025.11.12
# via httpcore
# via httpx
-colorlog==6.7.0
+colorlog==6.10.1
+ # via nox
+dependency-groups==1.3.1
# via nox
-dirty-equals==0.6.0
-distlib==0.3.7
+dirty-equals==0.11
+distlib==0.4.0
# via virtualenv
-distro==1.8.0
+distro==1.9.0
# via brainbase-labs
-exceptiongroup==1.2.2
+exceptiongroup==1.3.1
# via anyio
# via pytest
-filelock==3.12.4
+execnet==2.1.2
+ # via pytest-xdist
+filelock==3.19.1
# via virtualenv
-h11==0.14.0
+frozenlist==1.8.0
+ # via aiohttp
+ # via aiosignal
+h11==0.16.0
# via httpcore
-httpcore==1.0.2
+httpcore==1.0.9
# via httpx
httpx==0.28.1
# via brainbase-labs
+ # via httpx-aiohttp
# via respx
-idna==3.4
+httpx-aiohttp==0.1.9
+ # via brainbase-labs
+humanize==4.13.0
+ # via nox
+idna==3.11
# via anyio
# via httpx
-importlib-metadata==7.0.0
-iniconfig==2.0.0
+ # via yarl
+importlib-metadata==8.7.0
+iniconfig==2.1.0
# via pytest
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
-mypy==1.14.1
-mypy-extensions==1.0.0
+multidict==6.7.0
+ # via aiohttp
+ # via yarl
+mypy==1.17.0
+mypy-extensions==1.1.0
# via mypy
-nest-asyncio==1.6.0
-nodeenv==1.8.0
+nodeenv==1.9.1
# via pyright
-nox==2023.4.22
-packaging==23.2
+nox==2025.11.12
+packaging==25.0
+ # via dependency-groups
# via nox
# via pytest
-platformdirs==3.11.0
+pathspec==0.12.1
+ # via mypy
+platformdirs==4.4.0
# via virtualenv
-pluggy==1.5.0
+pluggy==1.6.0
# via pytest
-pydantic==2.10.3
+propcache==0.4.1
+ # via aiohttp
+ # via yarl
+pydantic==2.12.5
# via brainbase-labs
-pydantic-core==2.27.1
+pydantic-core==2.41.5
# via pydantic
-pygments==2.18.0
+pygments==2.19.2
+ # via pytest
# via rich
-pyright==1.1.392.post0
-pytest==8.3.3
+pyright==1.1.399
+pytest==8.4.2
# via pytest-asyncio
-pytest-asyncio==0.24.0
-python-dateutil==2.8.2
+ # via pytest-xdist
+pytest-asyncio==1.2.0
+pytest-xdist==3.8.0
+python-dateutil==2.9.0.post0
# via time-machine
-pytz==2023.3.post1
- # via dirty-equals
respx==0.22.0
-rich==13.7.1
-ruff==0.9.4
-setuptools==68.2.2
- # via nodeenv
-six==1.16.0
+rich==14.2.0
+ruff==0.14.7
+six==1.17.0
# via python-dateutil
-sniffio==1.3.0
- # via anyio
+sniffio==1.3.1
# via brainbase-labs
-time-machine==2.9.0
-tomli==2.0.2
+time-machine==2.19.0
+tomli==2.3.0
+ # via dependency-groups
# via mypy
+ # via nox
# via pytest
-typing-extensions==4.12.2
+typing-extensions==4.15.0
+ # via aiosignal
# via anyio
# via brainbase-labs
+ # via exceptiongroup
+ # via multidict
# via mypy
# via pydantic
# via pydantic-core
# via pyright
-virtualenv==20.24.5
+ # via pytest-asyncio
+ # via typing-inspection
+ # via virtualenv
+typing-inspection==0.4.2
+ # via pydantic
+virtualenv==20.35.4
# via nox
-zipp==3.17.0
+yarl==1.22.0
+ # via aiohttp
+zipp==3.23.0
# via importlib-metadata
diff --git a/requirements.lock b/requirements.lock
index ab585024..46d844a3 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -7,38 +7,70 @@
# all-features: true
# with-sources: false
# generate-hashes: false
+# universal: false
-e file:.
-annotated-types==0.6.0
+aiohappyeyeballs==2.6.1
+ # via aiohttp
+aiohttp==3.13.2
+ # via brainbase-labs
+ # via httpx-aiohttp
+aiosignal==1.4.0
+ # via aiohttp
+annotated-types==0.7.0
# via pydantic
-anyio==4.4.0
+anyio==4.12.0
# via brainbase-labs
# via httpx
-certifi==2023.7.22
+async-timeout==5.0.1
+ # via aiohttp
+attrs==25.4.0
+ # via aiohttp
+certifi==2025.11.12
# via httpcore
# via httpx
-distro==1.8.0
+distro==1.9.0
# via brainbase-labs
-exceptiongroup==1.2.2
+exceptiongroup==1.3.1
# via anyio
-h11==0.14.0
+frozenlist==1.8.0
+ # via aiohttp
+ # via aiosignal
+h11==0.16.0
# via httpcore
-httpcore==1.0.2
+httpcore==1.0.9
# via httpx
httpx==0.28.1
# via brainbase-labs
-idna==3.4
+ # via httpx-aiohttp
+httpx-aiohttp==0.1.9
+ # via brainbase-labs
+idna==3.11
# via anyio
# via httpx
-pydantic==2.10.3
+ # via yarl
+multidict==6.7.0
+ # via aiohttp
+ # via yarl
+propcache==0.4.1
+ # via aiohttp
+ # via yarl
+pydantic==2.12.5
# via brainbase-labs
-pydantic-core==2.27.1
+pydantic-core==2.41.5
# via pydantic
-sniffio==1.3.0
- # via anyio
+sniffio==1.3.1
# via brainbase-labs
-typing-extensions==4.12.2
+typing-extensions==4.15.0
+ # via aiosignal
# via anyio
# via brainbase-labs
+ # via exceptiongroup
+ # via multidict
# via pydantic
# via pydantic-core
+ # via typing-inspection
+typing-inspection==0.4.2
+ # via pydantic
+yarl==1.22.0
+ # via aiohttp
diff --git a/scripts/bootstrap b/scripts/bootstrap
index e84fe62c..b430fee3 100755
--- a/scripts/bootstrap
+++ b/scripts/bootstrap
@@ -4,10 +4,18 @@ set -e
cd "$(dirname "$0")/.."
-if ! command -v rye >/dev/null 2>&1 && [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then
+if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then
brew bundle check >/dev/null 2>&1 || {
- echo "==> Installing Homebrew dependencies…"
- brew bundle
+ echo -n "==> Install Homebrew dependencies? (y/N): "
+ read -r response
+ case "$response" in
+ [yY][eE][sS]|[yY])
+ brew bundle
+ ;;
+ *)
+ ;;
+ esac
+ echo
}
fi
diff --git a/scripts/lint b/scripts/lint
index fa0c0399..83623cb6 100755
--- a/scripts/lint
+++ b/scripts/lint
@@ -4,8 +4,13 @@ set -e
cd "$(dirname "$0")/.."
-echo "==> Running lints"
-rye run lint
+if [ "$1" = "--fix" ]; then
+ echo "==> Running lints with --fix"
+ rye run fix:ruff
+else
+ echo "==> Running lints"
+ rye run lint
+fi
echo "==> Making sure it imports"
rye run python -c 'import brainbase'
diff --git a/scripts/mock b/scripts/mock
index d2814ae6..0b28f6ea 100755
--- a/scripts/mock
+++ b/scripts/mock
@@ -21,7 +21,7 @@ echo "==> Starting mock server with URL ${URL}"
# Run prism mock on the given spec
if [ "$1" == "--daemon" ]; then
- npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL" &> .prism.log &
+ npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &> .prism.log &
# Wait for server to come online
echo -n "Waiting for server"
@@ -37,5 +37,5 @@ if [ "$1" == "--daemon" ]; then
echo
else
- npm exec --package=@stainless-api/prism-cli@5.8.5 -- prism mock "$URL"
+ npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL"
fi
diff --git a/scripts/test b/scripts/test
index 4fa5698b..dbeda2d2 100755
--- a/scripts/test
+++ b/scripts/test
@@ -43,7 +43,7 @@ elif ! prism_is_running ; then
echo -e "To run the server, pass in the path or url of your OpenAPI"
echo -e "spec to the prism command:"
echo
- echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}"
+ echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}"
echo
exit 1
@@ -52,6 +52,8 @@ else
echo
fi
+export DEFER_PYDANTIC_BUILD=false
+
echo "==> Running tests"
rye run pytest "$@"
diff --git a/scripts/utils/upload-artifact.sh b/scripts/utils/upload-artifact.sh
new file mode 100755
index 00000000..64d4441e
--- /dev/null
+++ b/scripts/utils/upload-artifact.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+set -exuo pipefail
+
+FILENAME=$(basename dist/*.whl)
+
+RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \
+ -H "Authorization: Bearer $AUTH" \
+ -H "Content-Type: application/json")
+
+SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url')
+
+if [[ "$SIGNED_URL" == "null" ]]; then
+ echo -e "\033[31mFailed to get signed URL.\033[0m"
+ exit 1
+fi
+
+UPLOAD_RESPONSE=$(curl -v -X PUT \
+ -H "Content-Type: binary/octet-stream" \
+ --data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1)
+
+if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then
+ echo -e "\033[32mUploaded build to Stainless storage.\033[0m"
+ echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/brainbase-python/$SHA/$FILENAME'\033[0m"
+else
+ echo -e "\033[31mFailed to upload artifact.\033[0m"
+ exit 1
+fi
diff --git a/src/brainbase/__init__.py b/src/brainbase/__init__.py
index 9f3a57b1..0d991ee6 100644
--- a/src/brainbase/__init__.py
+++ b/src/brainbase/__init__.py
@@ -1,7 +1,9 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+import typing as _t
+
from . import types
-from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
+from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
from ._utils import file_from_path
from ._client import (
Client,
@@ -34,7 +36,7 @@
UnprocessableEntityError,
APIResponseValidationError,
)
-from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
+from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
__all__ = [
@@ -46,7 +48,9 @@
"ProxiesTypes",
"NotGiven",
"NOT_GIVEN",
+ "not_given",
"Omit",
+ "omit",
"BrainbaseError",
"APIError",
"APIStatusError",
@@ -76,8 +80,12 @@
"DEFAULT_CONNECTION_LIMITS",
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
+ "DefaultAioHttpClient",
]
+if not _t.TYPE_CHECKING:
+ from ._utils._resources_proxy import resources as resources
+
_setup_logging()
# Update the __module__ attribute for exported symbols so that
diff --git a/src/brainbase/_base_client.py b/src/brainbase/_base_client.py
index 9bb716ae..7849a30a 100644
--- a/src/brainbase/_base_client.py
+++ b/src/brainbase/_base_client.py
@@ -9,7 +9,6 @@
import inspect
import logging
import platform
-import warnings
import email.utils
from types import TracebackType
from random import random
@@ -36,14 +35,13 @@
import httpx
import distro
import pydantic
-from httpx import URL, Limits
+from httpx import URL
from pydantic import PrivateAttr
from . import _exceptions
from ._qs import Querystring
from ._files import to_httpx_files, async_to_httpx_files
from ._types import (
- NOT_GIVEN,
Body,
Omit,
Query,
@@ -51,19 +49,17 @@
Timeout,
NotGiven,
ResponseT,
- Transport,
AnyMapping,
PostParser,
- ProxiesTypes,
RequestFiles,
HttpxSendArgs,
- AsyncTransport,
RequestOptions,
HttpxRequestFiles,
ModelBuilderProtocol,
+ not_given,
)
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
-from ._compat import model_copy, model_dump
+from ._compat import PYDANTIC_V1, model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
APIResponse,
@@ -102,7 +98,11 @@
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
if TYPE_CHECKING:
- from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+ from httpx._config import (
+ DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage]
+ )
+
+ HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG
else:
try:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
@@ -119,6 +119,7 @@ class PageInfo:
url: URL | NotGiven
params: Query | NotGiven
+ json: Body | NotGiven
@overload
def __init__(
@@ -134,19 +135,30 @@ def __init__(
params: Query,
) -> None: ...
+ @overload
+ def __init__(
+ self,
+ *,
+ json: Body,
+ ) -> None: ...
+
def __init__(
self,
*,
- url: URL | NotGiven = NOT_GIVEN,
- params: Query | NotGiven = NOT_GIVEN,
+ url: URL | NotGiven = not_given,
+ json: Body | NotGiven = not_given,
+ params: Query | NotGiven = not_given,
) -> None:
self.url = url
+ self.json = json
self.params = params
@override
def __repr__(self) -> str:
if self.url:
return f"{self.__class__.__name__}(url={self.url})"
+ if self.json:
+ return f"{self.__class__.__name__}(json={self.json})"
return f"{self.__class__.__name__}(params={self.params})"
@@ -195,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options.url = str(url)
return options
+ if not isinstance(info.json, NotGiven):
+ if not is_mapping(info.json):
+ raise TypeError("Pagination is only supported with mappings")
+
+ if not options.json_data:
+ options.json_data = {**info.json}
+ else:
+ if not is_mapping(options.json_data):
+ raise TypeError("Pagination is only supported with mappings")
+
+ options.json_data = {**options.json_data, **info.json}
+ return options
+
raise ValueError("Unexpected PageInfo state")
@@ -207,6 +232,9 @@ def _set_private_attributes(
model: Type[_T],
options: FinalRequestOptions,
) -> None:
+ if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None:
+ self.__pydantic_private__ = {}
+
self._model = model
self._client = client
self._options = options
@@ -292,6 +320,9 @@ def _set_private_attributes(
client: AsyncAPIClient,
options: FinalRequestOptions,
) -> None:
+ if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None:
+ self.__pydantic_private__ = {}
+
self._model = model
self._client = client
self._options = options
@@ -331,9 +362,6 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_base_url: URL
max_retries: int
timeout: Union[float, Timeout, None]
- _limits: httpx.Limits
- _proxies: ProxiesTypes | None
- _transport: Transport | AsyncTransport | None
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None
@@ -346,9 +374,6 @@ def __init__(
_strict_response_validation: bool,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
- limits: httpx.Limits,
- transport: Transport | AsyncTransport | None,
- proxies: ProxiesTypes | None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
@@ -356,9 +381,6 @@ def __init__(
self._base_url = self._enforce_trailing_slash(URL(base_url))
self.max_retries = max_retries
self.timeout = timeout
- self._limits = limits
- self._proxies = proxies
- self._transport = transport
self._custom_headers = custom_headers or {}
self._custom_query = custom_query or {}
self._strict_response_validation = _strict_response_validation
@@ -415,13 +437,20 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
headers = httpx.Headers(headers_dict)
idempotency_header = self._idempotency_header
- if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
- headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
+ if idempotency_header and options.idempotency_key and idempotency_header not in headers:
+ headers[idempotency_header] = options.idempotency_key
- # Don't set the retry count header if it was already set or removed by the caller. We check
+ # Don't set these headers if they were already set or removed by the caller. We check
# `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
- if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
+ lower_custom_headers = [header.lower() for header in custom_headers]
+ if "x-stainless-retry-count" not in lower_custom_headers:
headers["x-stainless-retry-count"] = str(retries_taken)
+ if "x-stainless-read-timeout" not in lower_custom_headers:
+ timeout = self.timeout if isinstance(options.timeout, NotGiven) else options.timeout
+ if isinstance(timeout, Timeout):
+ timeout = timeout.read
+ if timeout is not None:
+ headers["x-stainless-read-timeout"] = str(timeout)
return headers
@@ -500,6 +529,18 @@ def _build_request(
# work around https://github.com/encode/httpx/discussions/2880
kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")}
+ is_body_allowed = options.method.lower() != "get"
+
+ if is_body_allowed:
+ if isinstance(json_data, bytes):
+ kwargs["content"] = json_data
+ else:
+ kwargs["json"] = json_data if is_given(json_data) else None
+ kwargs["files"] = files
+ else:
+ headers.pop("Content-Type", None)
+ kwargs.pop("data", None)
+
# TODO: report this error to httpx
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
@@ -511,8 +552,6 @@ def _build_request(
# so that passing a `TypedDict` doesn't cause an error.
# https://github.com/microsoft/pyright/issues/3526#event-6715453066
params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
- json=json_data,
- files=files,
**kwargs,
)
@@ -556,7 +595,7 @@ def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalReques
# we internally support defining a temporary header to override the
# default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response`
# see _response.py for implementation details
- override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN)
+ override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given)
if is_given(override_cast_to):
options.headers = headers
return cast(Type[ResponseT], override_cast_to)
@@ -786,47 +825,12 @@ def __init__(
version: str,
base_url: str | URL,
max_retries: int = DEFAULT_MAX_RETRIES,
- timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
- transport: Transport | None = None,
- proxies: ProxiesTypes | None = None,
- limits: Limits | None = None,
+ timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
) -> None:
- kwargs: dict[str, Any] = {}
- if limits is not None:
- warnings.warn(
- "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
- else:
- limits = DEFAULT_CONNECTION_LIMITS
-
- if transport is not None:
- kwargs["transport"] = transport
- warnings.warn(
- "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
-
- if proxies is not None:
- kwargs["proxies"] = proxies
- warnings.warn(
- "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
-
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
# timeout set then we use that timeout.
@@ -847,12 +851,9 @@ def __init__(
super().__init__(
version=version,
- limits=limits,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
- proxies=proxies,
base_url=base_url,
- transport=transport,
max_retries=max_retries,
custom_query=custom_query,
custom_headers=custom_headers,
@@ -862,9 +863,6 @@ def __init__(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
- limits=limits,
- follow_redirects=True,
- **kwargs, # type: ignore
)
def is_closed(self) -> bool:
@@ -914,7 +912,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[True],
stream_cls: Type[_StreamT],
@@ -925,7 +922,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[False] = False,
) -> ResponseT: ...
@@ -935,7 +931,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: Type[_StreamT] | None = None,
@@ -945,121 +940,112 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
+ cast_to = self._maybe_override_cast_to(cast_to, options)
- def _request(
- self,
- *,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- retries_taken: int,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
+ # ensure the idempotency key is reused between requests
+ input_options.idempotency_key = self._idempotency_key()
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = self._prepare_options(options)
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- self._prepare_request(request)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = self._prepare_options(options)
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ self._prepare_request(request)
- log.debug("Sending HTTP Request: %s %s", request.method, request.url)
-
- try:
- response = self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ if options.follow_redirects is not None:
+ kwargs["follow_redirects"] = options.follow_redirects
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
+ response = None
+ try:
+ response = self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
)
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
-
- log.debug(
- 'HTTP Response: %s %s "%i %s" %s',
- request.method,
- request.url,
- response.status_code,
- response.reason_phrase,
- response.headers,
- )
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ err.response.close()
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- err.response.close()
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ err.response.read()
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- err.response.read()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return self._process_response(
cast_to=cast_to,
options=options,
@@ -1069,37 +1055,20 @@ def _request(
retries_taken=retries_taken,
)
- def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
- # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
- # different thread if necessary.
time.sleep(timeout)
- return self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
def _process_response(
self,
*,
@@ -1112,7 +1081,14 @@ def _process_response(
) -> ResponseT:
origin = get_origin(cast_to) or cast_to
- if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+ if (
+ inspect.isclass(origin)
+ and issubclass(origin, BaseAPIResponse)
+ # we only want to actually return the custom BaseAPIResponse class if we're
+ # returning the raw response, or if we're not streaming SSE, as if we're streaming
+ # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+ and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+ ):
if not issubclass(origin, APIResponse):
raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")
@@ -1271,9 +1247,12 @@ def patch(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
- opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
+ opts = FinalRequestOptions.construct(
+ method="patch", url=path, json_data=body, files=to_httpx_files(files), **options
+ )
return self.request(cast_to, opts)
def put(
@@ -1323,6 +1302,24 @@ def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
+try:
+ import httpx_aiohttp
+except ImportError:
+
+ class _DefaultAioHttpClient(httpx.AsyncClient):
+ def __init__(self, **_kwargs: Any) -> None:
+ raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra")
+else:
+
+ class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore
+ def __init__(self, **kwargs: Any) -> None:
+ kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+ kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+ kwargs.setdefault("follow_redirects", True)
+
+ super().__init__(**kwargs)
+
+
if TYPE_CHECKING:
DefaultAsyncHttpxClient = httpx.AsyncClient
"""An alias to `httpx.AsyncClient` that provides the same defaults that this SDK
@@ -1331,8 +1328,12 @@ def __init__(self, **kwargs: Any) -> None:
This is useful because overriding the `http_client` with your own instance of
`httpx.AsyncClient` will result in httpx's defaults being used, not ours.
"""
+
+ DefaultAioHttpClient = httpx.AsyncClient
+ """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`."""
else:
DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient
+ DefaultAioHttpClient = _DefaultAioHttpClient
class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient):
@@ -1358,46 +1359,11 @@ def __init__(
base_url: str | URL,
_strict_response_validation: bool,
max_retries: int = DEFAULT_MAX_RETRIES,
- timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
- transport: AsyncTransport | None = None,
- proxies: ProxiesTypes | None = None,
- limits: Limits | None = None,
+ timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
- kwargs: dict[str, Any] = {}
- if limits is not None:
- warnings.warn(
- "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
- else:
- limits = DEFAULT_CONNECTION_LIMITS
-
- if transport is not None:
- kwargs["transport"] = transport
- warnings.warn(
- "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
-
- if proxies is not None:
- kwargs["proxies"] = proxies
- warnings.warn(
- "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
- category=DeprecationWarning,
- stacklevel=3,
- )
- if http_client is not None:
- raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
-
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
# timeout set then we use that timeout.
@@ -1419,11 +1385,8 @@ def __init__(
super().__init__(
version=version,
base_url=base_url,
- limits=limits,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
- proxies=proxies,
- transport=transport,
max_retries=max_retries,
custom_query=custom_query,
custom_headers=custom_headers,
@@ -1433,9 +1396,6 @@ def __init__(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
- limits=limits,
- follow_redirects=True,
- **kwargs, # type: ignore
)
def is_closed(self) -> bool:
@@ -1484,7 +1444,6 @@ async def request(
options: FinalRequestOptions,
*,
stream: Literal[False] = False,
- remaining_retries: Optional[int] = None,
) -> ResponseT: ...
@overload
@@ -1495,7 +1454,6 @@ async def request(
*,
stream: Literal[True],
stream_cls: type[_AsyncStreamT],
- remaining_retries: Optional[int] = None,
) -> _AsyncStreamT: ...
@overload
@@ -1506,7 +1464,6 @@ async def request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT: ...
async def request(
@@ -1516,116 +1473,114 @@ async def request(
*,
stream: bool = False,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
- ) -> ResponseT | _AsyncStreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return await self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
-
- async def _request(
- self,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- *,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()
+ cast_to = self._maybe_override_cast_to(cast_to, options)
+
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
+ # ensure the idempotency key is reused between requests
+ input_options.idempotency_key = self._idempotency_key()
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = await self._prepare_options(options)
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- await self._prepare_request(request)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = await self._prepare_options(options)
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ await self._prepare_request(request)
- try:
- response = await self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ if options.follow_redirects is not None:
+ kwargs["follow_redirects"] = options.follow_redirects
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
+ response = None
+ try:
+ response = await self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
)
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ await err.response.aclose()
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- log.debug(
- 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ await err.response.aread()
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- await err.response.aclose()
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
-
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- await err.response.aread()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return await self._process_response(
cast_to=cast_to,
options=options,
@@ -1635,35 +1590,20 @@ async def _request(
retries_taken=retries_taken,
)
- async def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- ) -> ResponseT | _AsyncStreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ async def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
await anyio.sleep(timeout)
- return await self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
async def _process_response(
self,
*,
@@ -1676,7 +1616,14 @@ async def _process_response(
) -> ResponseT:
origin = get_origin(cast_to) or cast_to
- if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+ if (
+ inspect.isclass(origin)
+ and issubclass(origin, BaseAPIResponse)
+ # we only want to actually return the custom BaseAPIResponse class if we're
+ # returning the raw response, or if we're not streaming SSE, as if we're streaming
+ # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+ and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+ ):
if not issubclass(origin, AsyncAPIResponse):
raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}")
@@ -1823,9 +1770,12 @@ async def patch(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
- opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options)
+ opts = FinalRequestOptions.construct(
+ method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ )
return await self.request(cast_to, opts)
async def put(
@@ -1874,8 +1824,8 @@ def make_request_options(
extra_query: Query | None = None,
extra_body: Body | None = None,
idempotency_key: str | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- post_parser: PostParser | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ post_parser: PostParser | NotGiven = not_given,
) -> RequestOptions:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
options: RequestOptions = {}
diff --git a/src/brainbase/_client.py b/src/brainbase/_client.py
index c42e10d6..541ea815 100644
--- a/src/brainbase/_client.py
+++ b/src/brainbase/_client.py
@@ -3,7 +3,7 @@
from __future__ import annotations
import os
-from typing import Any, Union, Mapping
+from typing import TYPE_CHECKING, Any, Mapping
from typing_extensions import Self, override
import httpx
@@ -11,18 +11,16 @@
from . import _exceptions
from ._qs import Querystring
from ._types import (
- NOT_GIVEN,
Omit,
Timeout,
NotGiven,
Transport,
ProxiesTypes,
RequestOptions,
+ not_given,
)
-from ._utils import (
- is_given,
- get_async_library,
-)
+from ._utils import is_given, get_async_library
+from ._compat import cached_property
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import APIStatusError, BrainbaseError
@@ -31,7 +29,10 @@
SyncAPIClient,
AsyncAPIClient,
)
-from .resources.workers import workers
+
+if TYPE_CHECKING:
+ from .resources import workers
+ from .resources.workers.workers import WorkersResource, AsyncWorkersResource
__all__ = [
"Timeout",
@@ -46,10 +47,6 @@
class Brainbase(SyncAPIClient):
- workers: workers.WorkersResource
- with_raw_response: BrainbaseWithRawResponse
- with_streaming_response: BrainbaseWithStreamedResponse
-
# client options
api_key: str
@@ -58,7 +55,7 @@ def __init__(
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
- timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+ timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
@@ -76,7 +73,7 @@ def __init__(
# part of our public interface in the future.
_strict_response_validation: bool = False,
) -> None:
- """Construct a new synchronous brainbase client instance.
+ """Construct a new synchronous Brainbase client instance.
This automatically infers the `api_key` argument from the `API_KEY` environment variable if it is not provided.
"""
@@ -104,9 +101,19 @@ def __init__(
_strict_response_validation=_strict_response_validation,
)
- self.workers = workers.WorkersResource(self)
- self.with_raw_response = BrainbaseWithRawResponse(self)
- self.with_streaming_response = BrainbaseWithStreamedResponse(self)
+ @cached_property
+ def workers(self) -> WorkersResource:
+ from .resources.workers import WorkersResource
+
+ return WorkersResource(self)
+
+ @cached_property
+ def with_raw_response(self) -> BrainbaseWithRawResponse:
+ return BrainbaseWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> BrainbaseWithStreamedResponse:
+ return BrainbaseWithStreamedResponse(self)
@property
@override
@@ -133,9 +140,9 @@ def copy(
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
- timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.Client | None = None,
- max_retries: int | NotGiven = NOT_GIVEN,
+ max_retries: int | NotGiven = not_given,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
@@ -214,10 +221,6 @@ def _make_status_error(
class AsyncBrainbase(AsyncAPIClient):
- workers: workers.AsyncWorkersResource
- with_raw_response: AsyncBrainbaseWithRawResponse
- with_streaming_response: AsyncBrainbaseWithStreamedResponse
-
# client options
api_key: str
@@ -226,7 +229,7 @@ def __init__(
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
- timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+ timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
@@ -244,7 +247,7 @@ def __init__(
# part of our public interface in the future.
_strict_response_validation: bool = False,
) -> None:
- """Construct a new async brainbase client instance.
+ """Construct a new async AsyncBrainbase client instance.
This automatically infers the `api_key` argument from the `API_KEY` environment variable if it is not provided.
"""
@@ -272,9 +275,19 @@ def __init__(
_strict_response_validation=_strict_response_validation,
)
- self.workers = workers.AsyncWorkersResource(self)
- self.with_raw_response = AsyncBrainbaseWithRawResponse(self)
- self.with_streaming_response = AsyncBrainbaseWithStreamedResponse(self)
+ @cached_property
+ def workers(self) -> AsyncWorkersResource:
+ from .resources.workers import AsyncWorkersResource
+
+ return AsyncWorkersResource(self)
+
+ @cached_property
+ def with_raw_response(self) -> AsyncBrainbaseWithRawResponse:
+ return AsyncBrainbaseWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncBrainbaseWithStreamedResponse:
+ return AsyncBrainbaseWithStreamedResponse(self)
@property
@override
@@ -301,9 +314,9 @@ def copy(
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
- timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.AsyncClient | None = None,
- max_retries: int | NotGiven = NOT_GIVEN,
+ max_retries: int | NotGiven = not_given,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
@@ -382,23 +395,55 @@ def _make_status_error(
class BrainbaseWithRawResponse:
+ _client: Brainbase
+
def __init__(self, client: Brainbase) -> None:
- self.workers = workers.WorkersResourceWithRawResponse(client.workers)
+ self._client = client
+
+ @cached_property
+ def workers(self) -> workers.WorkersResourceWithRawResponse:
+ from .resources.workers import WorkersResourceWithRawResponse
+
+ return WorkersResourceWithRawResponse(self._client.workers)
class AsyncBrainbaseWithRawResponse:
+ _client: AsyncBrainbase
+
def __init__(self, client: AsyncBrainbase) -> None:
- self.workers = workers.AsyncWorkersResourceWithRawResponse(client.workers)
+ self._client = client
+
+ @cached_property
+ def workers(self) -> workers.AsyncWorkersResourceWithRawResponse:
+ from .resources.workers import AsyncWorkersResourceWithRawResponse
+
+ return AsyncWorkersResourceWithRawResponse(self._client.workers)
class BrainbaseWithStreamedResponse:
+ _client: Brainbase
+
def __init__(self, client: Brainbase) -> None:
- self.workers = workers.WorkersResourceWithStreamingResponse(client.workers)
+ self._client = client
+
+ @cached_property
+ def workers(self) -> workers.WorkersResourceWithStreamingResponse:
+ from .resources.workers import WorkersResourceWithStreamingResponse
+
+ return WorkersResourceWithStreamingResponse(self._client.workers)
class AsyncBrainbaseWithStreamedResponse:
+ _client: AsyncBrainbase
+
def __init__(self, client: AsyncBrainbase) -> None:
- self.workers = workers.AsyncWorkersResourceWithStreamingResponse(client.workers)
+ self._client = client
+
+ @cached_property
+ def workers(self) -> workers.AsyncWorkersResourceWithStreamingResponse:
+ from .resources.workers import AsyncWorkersResourceWithStreamingResponse
+
+ return AsyncWorkersResourceWithStreamingResponse(self._client.workers)
Client = Brainbase
diff --git a/src/brainbase/_compat.py b/src/brainbase/_compat.py
index 92d9ee61..bdef67f0 100644
--- a/src/brainbase/_compat.py
+++ b/src/brainbase/_compat.py
@@ -12,14 +12,13 @@
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
-# --------------- Pydantic v2 compatibility ---------------
+# --------------- Pydantic v2, v3 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
-PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
+PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
-# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
@@ -44,90 +43,92 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
else:
- if PYDANTIC_V2:
- from pydantic.v1.typing import (
+ # v1 re-exports
+ if PYDANTIC_V1:
+ from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
- from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
+ from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
- from pydantic.typing import (
+ from ._utils import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
+ parse_date as parse_date,
is_typeddict as is_typeddict,
+ parse_datetime as parse_datetime,
is_literal_type as is_literal_type,
)
- from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
else:
- if PYDANTIC_V2:
- from pydantic import ConfigDict
- else:
+ if PYDANTIC_V1:
# TODO: provide an error message here?
ConfigDict = None
+ else:
+ from pydantic import ConfigDict as ConfigDict
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
- if PYDANTIC_V2:
- return model.model_validate(value)
- else:
+ if PYDANTIC_V1:
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+ else:
+ return model.model_validate(value)
def field_is_required(field: FieldInfo) -> bool:
- if PYDANTIC_V2:
- return field.is_required()
- return field.required # type: ignore
+ if PYDANTIC_V1:
+ return field.required # type: ignore
+ return field.is_required()
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
- if PYDANTIC_V2:
- from pydantic_core import PydanticUndefined
-
- if value == PydanticUndefined:
- return None
+ if PYDANTIC_V1:
return value
+ from pydantic_core import PydanticUndefined
+
+ if value == PydanticUndefined:
+ return None
return value
def field_outer_type(field: FieldInfo) -> Any:
- if PYDANTIC_V2:
- return field.annotation
- return field.outer_type_ # type: ignore
+ if PYDANTIC_V1:
+ return field.outer_type_ # type: ignore
+ return field.annotation
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
- if PYDANTIC_V2:
- return model.model_config
- return model.__config__ # type: ignore
+ if PYDANTIC_V1:
+ return model.__config__ # type: ignore
+ return model.model_config
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
- if PYDANTIC_V2:
- return model.model_fields
- return model.__fields__ # type: ignore
+ if PYDANTIC_V1:
+ return model.__fields__ # type: ignore
+ return model.model_fields
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
- if PYDANTIC_V2:
- return model.model_copy(deep=deep)
- return model.copy(deep=deep) # type: ignore
+ if PYDANTIC_V1:
+ return model.copy(deep=deep) # type: ignore
+ return model.model_copy(deep=deep)
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
- if PYDANTIC_V2:
- return model.model_dump_json(indent=indent)
- return model.json(indent=indent) # type: ignore
+ if PYDANTIC_V1:
+ return model.json(indent=indent) # type: ignore
+ return model.model_dump_json(indent=indent)
def model_dump(
@@ -139,14 +140,14 @@ def model_dump(
warnings: bool = True,
mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
- if PYDANTIC_V2 or hasattr(model, "model_dump"):
+ if (not PYDANTIC_V1) or hasattr(model, "model_dump"):
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
# warnings are not supported in Pydantic v1
- warnings=warnings if PYDANTIC_V2 else True,
+ warnings=True if PYDANTIC_V1 else warnings,
)
return cast(
"dict[str, Any]",
@@ -159,9 +160,9 @@ def model_dump(
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
- if PYDANTIC_V2:
- return model.model_validate(data)
- return model.parse_obj(data) # pyright: ignore[reportDeprecated]
+ if PYDANTIC_V1:
+ return model.parse_obj(data) # pyright: ignore[reportDeprecated]
+ return model.model_validate(data)
# generic models
@@ -170,17 +171,16 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
class GenericModel(pydantic.BaseModel): ...
else:
- if PYDANTIC_V2:
+ if PYDANTIC_V1:
+ import pydantic.generics
+
+ class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
+ else:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
- else:
- import pydantic.generics
-
- class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
-
# cached properties
if TYPE_CHECKING:
diff --git a/src/brainbase/_files.py b/src/brainbase/_files.py
index 715cc207..cc14c14f 100644
--- a/src/brainbase/_files.py
+++ b/src/brainbase/_files.py
@@ -69,12 +69,12 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
return file
if is_tuple_t(file):
- return (file[0], _read_file_content(file[1]), *file[2:])
+ return (file[0], read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
-def _read_file_content(file: FileContent) -> HttpxFileContent:
+def read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file
@@ -111,12 +111,12 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
return file
if is_tuple_t(file):
- return (file[0], await _async_read_file_content(file[1]), *file[2:])
+ return (file[0], await async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
-async def _async_read_file_content(file: FileContent) -> HttpxFileContent:
+async def async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()
diff --git a/src/brainbase/_models.py b/src/brainbase/_models.py
index 12c34b7d..ca9500b2 100644
--- a/src/brainbase/_models.py
+++ b/src/brainbase/_models.py
@@ -2,9 +2,11 @@
import os
import inspect
-from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
+import weakref
+from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
+ List,
Unpack,
Literal,
ClassVar,
@@ -19,7 +21,6 @@
)
import pydantic
-import pydantic.generics
from pydantic.fields import FieldInfo
from ._types import (
@@ -50,7 +51,7 @@
strip_annotated_type,
)
from ._compat import (
- PYDANTIC_V2,
+ PYDANTIC_V1,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
@@ -65,7 +66,7 @@
from ._constants import RAW_RESPONSE_HEADER
if TYPE_CHECKING:
- from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
+ from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
@@ -81,11 +82,7 @@ class _ConfigProtocol(Protocol):
class BaseModel(pydantic.BaseModel):
- if PYDANTIC_V2:
- model_config: ClassVar[ConfigDict] = ConfigDict(
- extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
- )
- else:
+ if PYDANTIC_V1:
@property
@override
@@ -95,6 +92,10 @@ def model_fields_set(self) -> set[str]:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
+ else:
+ model_config: ClassVar[ConfigDict] = ConfigDict(
+ extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
+ )
def to_dict(
self,
@@ -208,28 +209,32 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride]
else:
fields_values[name] = field_get_default(field)
+ extra_field_type = _get_extra_fields_type(__cls)
+
_extra = {}
for key, value in values.items():
if key not in model_fields:
- if PYDANTIC_V2:
- _extra[key] = value
- else:
+ parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
+
+ if PYDANTIC_V1:
_fields_set.add(key)
- fields_values[key] = value
+ fields_values[key] = parsed
+ else:
+ _extra[key] = parsed
object.__setattr__(m, "__dict__", fields_values)
- if PYDANTIC_V2:
- # these properties are copied from Pydantic's `model_construct()` method
- object.__setattr__(m, "__pydantic_private__", None)
- object.__setattr__(m, "__pydantic_extra__", _extra)
- object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
- else:
+ if PYDANTIC_V1:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set)
+ else:
+ # these properties are copied from Pydantic's `model_construct()` method
+ object.__setattr__(m, "__pydantic_private__", None)
+ object.__setattr__(m, "__pydantic_extra__", _extra)
+ object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
return m
@@ -239,7 +244,7 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride]
# although not in practice
model_construct = construct
- if not PYDANTIC_V2:
+ if PYDANTIC_V1:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
@@ -252,13 +257,15 @@ def model_dump(
mode: Literal["json", "python"] | str = "python",
include: IncEx | None = None,
exclude: IncEx | None = None,
- by_alias: bool = False,
+ context: Any | None = None,
+ by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
+ exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
- context: dict[str, Any] | None = None,
+ fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
@@ -267,16 +274,24 @@ def model_dump(
Args:
mode: The mode in which `to_python` should run.
- If mode is 'json', the dictionary will only contain JSON serializable types.
- If mode is 'python', the dictionary may contain any Python objects.
- include: A list of fields to include in the output.
- exclude: A list of fields to exclude from the output.
+ If mode is 'json', the output will only contain JSON serializable types.
+ If mode is 'python', the output may contain non-JSON-serializable Python objects.
+ include: A set of fields to include in the output.
+ exclude: A set of fields to exclude from the output.
+ context: Additional context to pass to the serializer.
by_alias: Whether to use the field's alias in the dictionary key if defined.
- exclude_unset: Whether to exclude fields that are unset or None from the output.
- exclude_defaults: Whether to exclude fields that are set to their default value from the output.
- exclude_none: Whether to exclude fields that have a value of `None` from the output.
- round_trip: Whether to enable serialization and deserialization round-trip support.
- warnings: Whether to log warnings when invalid fields are encountered.
+ exclude_unset: Whether to exclude fields that have not been explicitly set.
+ exclude_defaults: Whether to exclude fields that are set to their default value.
+ exclude_none: Whether to exclude fields that have a value of `None`.
+ exclude_computed_fields: Whether to exclude computed fields.
+ While this can be useful for round-tripping, it is usually recommended to use the dedicated
+ `round_trip` parameter instead.
+ round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
+ warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
+ "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
+ fallback: A function to call when an unknown value is encountered. If not provided,
+ a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
+ serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
Returns:
A dictionary representation of the model.
@@ -291,31 +306,38 @@ def model_dump(
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
+ if fallback is not None:
+ raise ValueError("fallback is only supported in Pydantic v2")
+ if exclude_computed_fields != False:
+ raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
- by_alias=by_alias,
+ by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
- return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped
+ return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped
@override
def model_dump_json(
self,
*,
indent: int | None = None,
+ ensure_ascii: bool = False,
include: IncEx | None = None,
exclude: IncEx | None = None,
- by_alias: bool = False,
+ context: Any | None = None,
+ by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
+ exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
- context: dict[str, Any] | None = None,
+ fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
@@ -344,11 +366,17 @@ def model_dump_json(
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
+ if fallback is not None:
+ raise ValueError("fallback is only supported in Pydantic v2")
+ if ensure_ascii != False:
+ raise ValueError("ensure_ascii is only supported in Pydantic v2")
+ if exclude_computed_fields != False:
+ raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
- by_alias=by_alias,
+ by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
@@ -359,15 +387,32 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
- if PYDANTIC_V2:
- type_ = field.annotation
- else:
+ if PYDANTIC_V1:
type_ = cast(type, field.outer_type_) # type: ignore
+ else:
+ type_ = field.annotation # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
- return construct_type(value=value, type_=type_)
+ return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
+
+
+def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
+ if PYDANTIC_V1:
+ # TODO
+ return None
+
+ schema = cls.__pydantic_core_schema__
+ if schema["type"] == "model":
+ fields = schema["schema"]
+ if fields["type"] == "model-fields":
+ extras = fields.get("extras_schema")
+ if extras and "cls" in extras:
+ # mypy can't narrow the type
+ return extras["cls"] # type: ignore[no-any-return]
+
+ return None
def is_basemodel(type_: type) -> bool:
@@ -421,20 +466,28 @@ def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
return cast(_T, construct_type(value=value, type_=type_))
-def construct_type(*, value: object, type_: object) -> object:
+def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
+
+ # store a reference to the original type we were given before we extract any inner
+ # types so that we can properly resolve forward references in `TypeAliasType` annotations
+ original_type = None
+
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
if is_type_alias_type(type_):
+ original_type = type_ # type: ignore[unreachable]
type_ = type_.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
- if is_annotated_type(type_):
- meta: tuple[Any, ...] = get_args(type_)[1:]
+ if metadata is not None and len(metadata) > 0:
+ meta: tuple[Any, ...] = tuple(metadata)
+ elif is_annotated_type(type_):
+ meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()
@@ -446,7 +499,7 @@ def construct_type(*, value: object, type_: object) -> object:
if is_union(origin):
try:
- return validate_type(type_=cast("type[object]", type_), value=value)
+ return validate_type(type_=cast("type[object]", original_type or type_), value=value)
except Exception:
pass
@@ -538,6 +591,9 @@ class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
+DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
+
+
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
@@ -580,8 +636,9 @@ def __init__(
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
- if isinstance(union, CachedDiscriminatorType):
- return union.__discriminator__
+ cached = DISCRIMINATOR_CACHE.get(union)
+ if cached is not None:
+ return cached
discriminator_field_name: str | None = None
@@ -599,30 +656,30 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
- if PYDANTIC_V2:
- field = _extract_field_schema_pv2(variant, discriminator_field_name)
- if not field:
+ if PYDANTIC_V1:
+ field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
+ if not field_info:
continue
# Note: if one variant defines an alias then they all should
- discriminator_alias = field.get("serialization_alias")
-
- field_schema = field["schema"]
+ discriminator_alias = field_info.alias
- if field_schema["type"] == "literal":
- for entry in cast("LiteralSchema", field_schema)["expected"]:
+ if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
+ for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
else:
- field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
- if not field_info:
+ field = _extract_field_schema_pv2(variant, discriminator_field_name)
+ if not field:
continue
# Note: if one variant defines an alias then they all should
- discriminator_alias = field_info.alias
+ discriminator_alias = field.get("serialization_alias")
+
+ field_schema = field["schema"]
- if field_info.annotation and is_literal_type(field_info.annotation):
- for entry in get_args(field_info.annotation):
+ if field_schema["type"] == "literal":
+ for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
@@ -634,21 +691,24 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
- cast(CachedDiscriminatorType, union).__discriminator__ = details
+ DISCRIMINATOR_CACHE.setdefault(union, details)
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
+ if schema["type"] == "definitions":
+ schema = schema["schema"]
+
if schema["type"] != "model":
return None
+ schema = cast("ModelSchema", schema)
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
-
field = fields_schema["fields"].get(field_name)
if not field:
return None
@@ -672,7 +732,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
setattr(typ, "__pydantic_config__", config) # noqa: B010
-# our use of subclasssing here causes weirdness for type checkers,
+# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
@@ -682,7 +742,7 @@ class GenericModel(BaseGenericModel, BaseModel):
pass
-if PYDANTIC_V2:
+if not PYDANTIC_V1:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
@@ -729,6 +789,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
idempotency_key: str
json_data: Body
extra_json: AnyMapping
+ follow_redirects: bool
@final
@@ -742,18 +803,19 @@ class FinalRequestOptions(pydantic.BaseModel):
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
+ follow_redirects: Union[bool, None] = None
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
- if PYDANTIC_V2:
- model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
- else:
+ if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
+ else:
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
@@ -786,9 +848,9 @@ def construct( # type: ignore
key: strip_not_given(value)
for key, value in values.items()
}
- if PYDANTIC_V2:
- return super().model_construct(_fields_set, **kwargs)
- return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
+ if PYDANTIC_V1:
+ return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
+ return super().model_construct(_fields_set, **kwargs)
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
diff --git a/src/brainbase/_qs.py b/src/brainbase/_qs.py
index 274320ca..ada6fd3f 100644
--- a/src/brainbase/_qs.py
+++ b/src/brainbase/_qs.py
@@ -4,7 +4,7 @@
from urllib.parse import parse_qs, urlencode
from typing_extensions import Literal, get_args
-from ._types import NOT_GIVEN, NotGiven, NotGivenOr
+from ._types import NotGiven, not_given
from ._utils import flatten
_T = TypeVar("_T")
@@ -41,8 +41,8 @@ def stringify(
self,
params: Params,
*,
- array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
- nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ array_format: ArrayFormat | NotGiven = not_given,
+ nested_format: NestedFormat | NotGiven = not_given,
) -> str:
return urlencode(
self.stringify_items(
@@ -56,8 +56,8 @@ def stringify_items(
self,
params: Params,
*,
- array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
- nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ array_format: ArrayFormat | NotGiven = not_given,
+ nested_format: NestedFormat | NotGiven = not_given,
) -> list[tuple[str, str]]:
opts = Options(
qs=self,
@@ -143,8 +143,8 @@ def __init__(
self,
qs: Querystring = _qs,
*,
- array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
- nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
+ array_format: ArrayFormat | NotGiven = not_given,
+ nested_format: NestedFormat | NotGiven = not_given,
) -> None:
self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format
diff --git a/src/brainbase/_response.py b/src/brainbase/_response.py
index 5c0e8f01..2b33470a 100644
--- a/src/brainbase/_response.py
+++ b/src/brainbase/_response.py
@@ -233,7 +233,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
- if content_type != "application/json":
+ if not content_type.endswith("json"):
if is_basemodel(cast_to):
try:
data = response.json()
diff --git a/src/brainbase/_streaming.py b/src/brainbase/_streaming.py
index ec0b6267..d3bc9002 100644
--- a/src/brainbase/_streaming.py
+++ b/src/brainbase/_streaming.py
@@ -54,12 +54,12 @@ def __stream__(self) -> Iterator[_T]:
process_data = self._client._process_response_data
iterator = self._iter_events()
- for sse in iterator:
- yield process_data(data=sse.json(), cast_to=cast_to, response=response)
-
- # Ensure the entire stream is consumed
- for _sse in iterator:
- ...
+ try:
+ for sse in iterator:
+ yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+ finally:
+ # Ensure the response is closed even if the consumer doesn't read all data
+ response.close()
def __enter__(self) -> Self:
return self
@@ -118,12 +118,12 @@ async def __stream__(self) -> AsyncIterator[_T]:
process_data = self._client._process_response_data
iterator = self._iter_events()
- async for sse in iterator:
- yield process_data(data=sse.json(), cast_to=cast_to, response=response)
-
- # Ensure the entire stream is consumed
- async for _sse in iterator:
- ...
+ try:
+ async for sse in iterator:
+ yield process_data(data=sse.json(), cast_to=cast_to, response=response)
+ finally:
+ # Ensure the response is closed even if the consumer doesn't read all data
+ await response.aclose()
async def __aenter__(self) -> Self:
return self
diff --git a/src/brainbase/_types.py b/src/brainbase/_types.py
index 1a11c5e5..03b1fc5a 100644
--- a/src/brainbase/_types.py
+++ b/src/brainbase/_types.py
@@ -13,10 +13,21 @@
Mapping,
TypeVar,
Callable,
+ Iterator,
Optional,
Sequence,
)
-from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
+from typing_extensions import (
+ Set,
+ Literal,
+ Protocol,
+ TypeAlias,
+ TypedDict,
+ SupportsIndex,
+ overload,
+ override,
+ runtime_checkable,
+)
import httpx
import pydantic
@@ -100,23 +111,27 @@ class RequestOptions(TypedDict, total=False):
params: Query
extra_json: AnyMapping
idempotency_key: str
+ follow_redirects: bool
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
- A sentinel singleton class used to distinguish omitted keyword arguments
- from those passed in with the value None (which may have different behavior).
+ For parameters with a meaningful None value, we need to distinguish between
+ the user explicitly passing None, and the user not passing the parameter at
+ all.
+
+ User code shouldn't need to use not_given directly.
For example:
```py
- def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+ def create(timeout: Timeout | None | NotGiven = not_given): ...
- get(timeout=1) # 1s timeout
- get(timeout=None) # No timeout
- get() # Default timeout behavior, which may not be statically known at the method definition.
+ create(timeout=1) # 1s timeout
+ create(timeout=None) # No timeout
+ create() # Default timeout behavior
```
"""
@@ -128,13 +143,14 @@ def __repr__(self) -> str:
return "NOT_GIVEN"
-NotGivenOr = Union[_T, NotGiven]
+not_given = NotGiven()
+# for backwards compatibility:
NOT_GIVEN = NotGiven()
class Omit:
- """In certain situations you need to be able to represent a case where a default value has
- to be explicitly removed and `None` is not an appropriate substitute, for example:
+ """
+ To explicitly omit something from being sent in a request, use `omit`.
```py
# as the default `Content-Type` header is `application/json` that will be sent
@@ -144,8 +160,8 @@ class Omit:
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={"Content-Type": "multipart/form-data"})
- # instead you can remove the default `application/json` header by passing Omit
- client.post(..., headers={"Content-Type": Omit()})
+ # instead you can remove the default `application/json` header by passing omit
+ client.post(..., headers={"Content-Type": omit})
```
"""
@@ -153,6 +169,9 @@ def __bool__(self) -> Literal[False]:
return False
+omit = Omit()
+
+
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
@@ -215,3 +234,28 @@ class _GenericAlias(Protocol):
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
+ follow_redirects: bool
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+if TYPE_CHECKING:
+ # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
+ # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
+ #
+ # Note: index() and count() methods are intentionally omitted to allow pyright to properly
+ # infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr.
+ class SequenceNotStr(Protocol[_T_co]):
+ @overload
+ def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
+ @overload
+ def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
+ def __contains__(self, value: object, /) -> bool: ...
+ def __len__(self) -> int: ...
+ def __iter__(self) -> Iterator[_T_co]: ...
+ def __reversed__(self) -> Iterator[_T_co]: ...
+else:
+ # just point this to a normal `Sequence` at runtime to avoid having to special case
+ # deserializing our custom sequence type
+ SequenceNotStr = Sequence
diff --git a/src/brainbase/_utils/__init__.py b/src/brainbase/_utils/__init__.py
index d4fda26f..dc64e29a 100644
--- a/src/brainbase/_utils/__init__.py
+++ b/src/brainbase/_utils/__init__.py
@@ -10,7 +10,6 @@
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
- parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
@@ -23,7 +22,6 @@
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
- parse_datetime as parse_datetime,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
@@ -32,12 +30,20 @@
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
)
+from ._compat import (
+ get_args as get_args,
+ is_union as is_union,
+ get_origin as get_origin,
+ is_typeddict as is_typeddict,
+ is_literal_type as is_literal_type,
+)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
+ is_sequence_type as is_sequence_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
@@ -55,3 +61,4 @@
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)
+from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
diff --git a/src/brainbase/_utils/_compat.py b/src/brainbase/_utils/_compat.py
new file mode 100644
index 00000000..dd703233
--- /dev/null
+++ b/src/brainbase/_utils/_compat.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import sys
+import typing_extensions
+from typing import Any, Type, Union, Literal, Optional
+from datetime import date, datetime
+from typing_extensions import get_args as _get_args, get_origin as _get_origin
+
+from .._types import StrBytesIntFloat
+from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime
+
+_LITERAL_TYPES = {Literal, typing_extensions.Literal}
+
+
+def get_args(tp: type[Any]) -> tuple[Any, ...]:
+ return _get_args(tp)
+
+
+def get_origin(tp: type[Any]) -> type[Any] | None:
+ return _get_origin(tp)
+
+
+def is_union(tp: Optional[Type[Any]]) -> bool:
+ if sys.version_info < (3, 10):
+ return tp is Union # type: ignore[comparison-overlap]
+ else:
+ import types
+
+ return tp is Union or tp is types.UnionType
+
+
+def is_typeddict(tp: Type[Any]) -> bool:
+ return typing_extensions.is_typeddict(tp)
+
+
+def is_literal_type(tp: Type[Any]) -> bool:
+ return get_origin(tp) in _LITERAL_TYPES
+
+
+def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
+ return _parse_date(value)
+
+
+def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
+ return _parse_datetime(value)
diff --git a/src/brainbase/_utils/_datetime_parse.py b/src/brainbase/_utils/_datetime_parse.py
new file mode 100644
index 00000000..7cb9d9e6
--- /dev/null
+++ b/src/brainbase/_utils/_datetime_parse.py
@@ -0,0 +1,136 @@
+"""
+This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py
+without the Pydantic v1 specific errors.
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Dict, Union, Optional
+from datetime import date, datetime, timezone, timedelta
+
+from .._types import StrBytesIntFloat
+
+date_expr = r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})"
+time_expr = (
+ r"(?P\d{1,2}):(?P\d{1,2})"
+ r"(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?"
+ r"(?PZ|[+-]\d{2}(?::?\d{2})?)?$"
+)
+
+date_re = re.compile(f"{date_expr}$")
+datetime_re = re.compile(f"{date_expr}[T ]{time_expr}")
+
+
+EPOCH = datetime(1970, 1, 1)
+# if greater than this, the number is in ms, if less than or equal it's in seconds
+# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
+MS_WATERSHED = int(2e10)
+# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
+MAX_NUMBER = int(3e20)
+
+
+def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
+ if isinstance(value, (int, float)):
+ return value
+ try:
+ return float(value)
+ except ValueError:
+ return None
+ except TypeError:
+ raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None
+
+
+def _from_unix_seconds(seconds: Union[int, float]) -> datetime:
+ if seconds > MAX_NUMBER:
+ return datetime.max
+ elif seconds < -MAX_NUMBER:
+ return datetime.min
+
+ while abs(seconds) > MS_WATERSHED:
+ seconds /= 1000
+ dt = EPOCH + timedelta(seconds=seconds)
+ return dt.replace(tzinfo=timezone.utc)
+
+
+def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]:
+ if value == "Z":
+ return timezone.utc
+ elif value is not None:
+ offset_mins = int(value[-2:]) if len(value) > 3 else 0
+ offset = 60 * int(value[1:3]) + offset_mins
+ if value[0] == "-":
+ offset = -offset
+ return timezone(timedelta(minutes=offset))
+ else:
+ return None
+
+
+def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
+ """
+ Parse a datetime/int/float/string and return a datetime.datetime.
+
+ This function supports time zone offsets. When the input contains one,
+ the output uses a timezone with a fixed offset from UTC.
+
+ Raise ValueError if the input is well formatted but not a valid datetime.
+ Raise ValueError if the input isn't well formatted.
+ """
+ if isinstance(value, datetime):
+ return value
+
+ number = _get_numeric(value, "datetime")
+ if number is not None:
+ return _from_unix_seconds(number)
+
+ if isinstance(value, bytes):
+ value = value.decode()
+
+ assert not isinstance(value, (float, int))
+
+ match = datetime_re.match(value)
+ if match is None:
+ raise ValueError("invalid datetime format")
+
+ kw = match.groupdict()
+ if kw["microsecond"]:
+ kw["microsecond"] = kw["microsecond"].ljust(6, "0")
+
+ tzinfo = _parse_timezone(kw.pop("tzinfo"))
+ kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
+ kw_["tzinfo"] = tzinfo
+
+ return datetime(**kw_) # type: ignore
+
+
+def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
+ """
+ Parse a date/int/float/string and return a datetime.date.
+
+ Raise ValueError if the input is well formatted but not a valid date.
+ Raise ValueError if the input isn't well formatted.
+ """
+ if isinstance(value, date):
+ if isinstance(value, datetime):
+ return value.date()
+ else:
+ return value
+
+ number = _get_numeric(value, "date")
+ if number is not None:
+ return _from_unix_seconds(number).date()
+
+ if isinstance(value, bytes):
+ value = value.decode()
+
+ assert not isinstance(value, (float, int))
+ match = date_re.match(value)
+ if match is None:
+ raise ValueError("invalid date format")
+
+ kw = {k: int(v) for k, v in match.groupdict().items()}
+
+ try:
+ return date(**kw)
+ except ValueError:
+ raise ValueError("invalid date format") from None
diff --git a/src/brainbase/_utils/_proxy.py b/src/brainbase/_utils/_proxy.py
index ffd883e9..0f239a33 100644
--- a/src/brainbase/_utils/_proxy.py
+++ b/src/brainbase/_utils/_proxy.py
@@ -46,7 +46,10 @@ def __dir__(self) -> Iterable[str]:
@property # type: ignore
@override
def __class__(self) -> type: # pyright: ignore
- proxied = self.__get_proxied__()
+ try:
+ proxied = self.__get_proxied__()
+ except Exception:
+ return type(self)
if issubclass(type(proxied), LazyProxy):
return type(proxied)
return proxied.__class__
diff --git a/src/brainbase/_utils/_resources_proxy.py b/src/brainbase/_utils/_resources_proxy.py
new file mode 100644
index 00000000..1ba107b7
--- /dev/null
+++ b/src/brainbase/_utils/_resources_proxy.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import Any
+from typing_extensions import override
+
+from ._proxy import LazyProxy
+
+
+class ResourcesProxy(LazyProxy[Any]):
+ """A proxy for the `brainbase.resources` module.
+
+ This is used so that we can lazily import `brainbase.resources` only when
+ needed *and* so that users can just import `brainbase` and reference `brainbase.resources`
+ """
+
+ @override
+ def __load__(self) -> Any:
+ import importlib
+
+ mod = importlib.import_module("brainbase.resources")
+ return mod
+
+
+resources = ResourcesProxy().__as_proxied__()
diff --git a/src/brainbase/_utils/_sync.py b/src/brainbase/_utils/_sync.py
index 8b3aaf2b..f6027c18 100644
--- a/src/brainbase/_utils/_sync.py
+++ b/src/brainbase/_utils/_sync.py
@@ -1,47 +1,34 @@
from __future__ import annotations
-import sys
import asyncio
import functools
-import contextvars
-from typing import Any, TypeVar, Callable, Awaitable
+from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
+import anyio
+import sniffio
+import anyio.to_thread
+
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
-if sys.version_info >= (3, 9):
- to_thread = asyncio.to_thread
-else:
- # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
- # for Python 3.8 support
- async def to_thread(
- func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
- ) -> Any:
- """Asynchronously run function *func* in a separate thread.
-
- Any *args and **kwargs supplied for this function are directly passed
- to *func*. Also, the current :class:`contextvars.Context` is propagated,
- allowing context variables from the main thread to be accessed in the
- separate thread.
+async def to_thread(
+ func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
+) -> T_Retval:
+ if sniffio.current_async_library() == "asyncio":
+ return await asyncio.to_thread(func, *args, **kwargs)
- Returns a coroutine that can be awaited to get the eventual result of *func*.
- """
- loop = asyncio.events.get_running_loop()
- ctx = contextvars.copy_context()
- func_call = functools.partial(ctx.run, func, *args, **kwargs)
- return await loop.run_in_executor(None, func_call)
+ return await anyio.to_thread.run_sync(
+ functools.partial(func, *args, **kwargs),
+ )
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
- positional and keyword arguments. For python version 3.9 and above, it uses
- asyncio.to_thread to run the function in a separate thread. For python version
- 3.8, it uses locally defined copy of the asyncio.to_thread function which was
- introduced in python 3.9.
+ positional and keyword arguments.
Usage:
diff --git a/src/brainbase/_utils/_transform.py b/src/brainbase/_utils/_transform.py
index a6b62cad..52075492 100644
--- a/src/brainbase/_utils/_transform.py
+++ b/src/brainbase/_utils/_transform.py
@@ -5,27 +5,31 @@
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
-from typing_extensions import Literal, get_args, override, get_type_hints
+from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
+ is_given,
+ lru_cache,
is_mapping,
is_iterable,
+ is_sequence,
)
from .._files import is_base64_file_input
+from ._compat import get_origin, is_typeddict
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
+ is_sequence_type,
is_annotated_type,
strip_annotated_type,
)
-from .._compat import model_dump, is_typeddict
_T = TypeVar("_T")
@@ -108,6 +112,7 @@ class Params(TypedDict, total=False):
return cast(_T, transformed)
+@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
@@ -126,7 +131,7 @@ def _get_annotated_type(type_: type) -> type | None:
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
- Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
+ Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
@@ -142,6 +147,10 @@ def _maybe_transform_key(key: str, type_: type) -> str:
return key
+def _no_transform_needed(annotation: type) -> bool:
+ return annotation == float or annotation == int
+
+
def _transform_recursive(
data: object,
*,
@@ -160,18 +169,27 @@ def _transform_recursive(
Defaults to the same value as the `annotation` argument.
"""
+ from .._compat import model_dump
+
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
+ origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
+ if origin == dict and is_mapping(data):
+ items_type = get_args(stripped_type)[1]
+ return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+ # Sequence[T]
+ or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
@@ -179,6 +197,15 @@ def _transform_recursive(
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
+ if _no_transform_needed(inner_type):
+ # for some types there is no need to transform anything, so we can get a small
+ # perf boost from skipping that work.
+ #
+ # but we still need to convert to a list to ensure the data is json-serializable
+ if is_list(data):
+ return data
+ return list(data)
+
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
@@ -240,6 +267,11 @@ def _transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
+ if not is_given(value):
+ # we don't need to include omitted values here as they'll
+ # be stripped out before the request is sent anyway
+ continue
+
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
@@ -303,18 +335,27 @@ async def _async_transform_recursive(
Defaults to the same value as the `annotation` argument.
"""
+ from .._compat import model_dump
+
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
+ origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
+ if origin == dict and is_mapping(data):
+ items_type = get_args(stripped_type)[1]
+ return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+ # Sequence[T]
+ or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
@@ -322,6 +363,15 @@ async def _async_transform_recursive(
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
+ if _no_transform_needed(inner_type):
+ # for some types there is no need to transform anything, so we can get a small
+ # perf boost from skipping that work.
+ #
+ # but we still need to convert to a list to ensure the data is json-serializable
+ if is_list(data):
+ return data
+ return list(data)
+
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
@@ -383,6 +433,11 @@ async def _async_transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
+ if not is_given(value):
+ # we don't need to include omitted values here as they'll
+ # be stripped out before the request is sent anyway
+ continue
+
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
@@ -390,3 +445,13 @@ async def _async_transform_typeddict(
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result
+
+
+@lru_cache(maxsize=8096)
+def get_type_hints(
+ obj: Any,
+ globalns: dict[str, Any] | None = None,
+ localns: Mapping[str, Any] | None = None,
+ include_extras: bool = False,
+) -> dict[str, Any]:
+ return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
diff --git a/src/brainbase/_utils/_typing.py b/src/brainbase/_utils/_typing.py
index 278749b1..193109f3 100644
--- a/src/brainbase/_utils/_typing.py
+++ b/src/brainbase/_utils/_typing.py
@@ -13,8 +13,9 @@
get_origin,
)
+from ._utils import lru_cache
from .._types import InheritsGeneric
-from .._compat import is_union as _is_union
+from ._compat import is_union as _is_union
def is_annotated_type(typ: type) -> bool:
@@ -25,6 +26,11 @@ def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
+def is_sequence_type(typ: type) -> bool:
+ origin = get_origin(typ) or typ
+ return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
+
+
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
@@ -66,6 +72,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
+@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
@@ -108,7 +115,7 @@ class MyResponse(Foo[_T]):
```
"""
cls = cast(object, get_origin(typ) or typ)
- if cls in generic_bases:
+ if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
# we're given the class directly
return extract_type_arg(typ, index)
diff --git a/src/brainbase/_utils/_utils.py b/src/brainbase/_utils/_utils.py
index e5811bba..eec7f4a1 100644
--- a/src/brainbase/_utils/_utils.py
+++ b/src/brainbase/_utils/_utils.py
@@ -21,8 +21,7 @@
import sniffio
-from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
-from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
+from .._types import Omit, NotGiven, FileTypes, HeadersLike
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
@@ -64,7 +63,7 @@ def _extract_items(
try:
key = path[index]
except IndexError:
- if isinstance(obj, NotGiven):
+ if not is_given(obj):
# no value was provided - we can safely ignore
return []
@@ -72,8 +71,16 @@ def _extract_items(
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
- assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
+
+ if is_list(obj):
+ files: list[tuple[str, FileTypes]] = []
+ for entry in obj:
+ assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
+ files.append((flattened_key + "[]", cast(FileTypes, entry)))
+ return files
+
+ assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]
index += 1
@@ -119,14 +126,14 @@ def _extract_items(
return []
-def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
- return not isinstance(obj, NotGiven)
+def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
+ return not isinstance(obj, NotGiven) and not isinstance(obj, Omit)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
-# care about the contained types we can safely use `object` in it's place.
+# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
diff --git a/src/brainbase/_version.py b/src/brainbase/_version.py
index 00b5ae6a..791d7b56 100644
--- a/src/brainbase/_version.py
+++ b/src/brainbase/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "brainbase"
-__version__ = "4.0.0" # x-release-please-version
+__version__ = "4.1.0" # x-release-please-version
diff --git a/src/brainbase/resources/workers/deployments/voice.py b/src/brainbase/resources/workers/deployments/voice.py
index 241c6123..f0e101c3 100644
--- a/src/brainbase/resources/workers/deployments/voice.py
+++ b/src/brainbase/resources/workers/deployments/voice.py
@@ -4,11 +4,8 @@
import httpx
-from ...._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ...._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ...._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given
+from ...._utils import maybe_transform, async_maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import (
@@ -19,10 +16,8 @@
)
from ...._base_client import make_request_options
from ....types.workers.deployments import voice_create_params, voice_update_params
+from ....types.workers.deployments.voice_deployment import VoiceDeployment
from ....types.workers.deployments.voice_list_response import VoiceListResponse
-from ....types.workers.deployments.voice_create_response import VoiceCreateResponse
-from ....types.workers.deployments.voice_update_response import VoiceUpdateResponse
-from ....types.workers.deployments.voice_retrieve_response import VoiceRetrieveResponse
__all__ = ["VoiceResource", "AsyncVoiceResource"]
@@ -52,16 +47,16 @@ def create(
worker_id: str,
*,
name: str,
- phone_number: str | NotGiven = NOT_GIVEN,
- voice_id: str | NotGiven = NOT_GIVEN,
- voice_provider: str | NotGiven = NOT_GIVEN,
+ phone_number: str | Omit = omit,
+ voice_id: str | Omit = omit,
+ voice_provider: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceCreateResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Create a new voice deployment
@@ -98,7 +93,7 @@ def create(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceCreateResponse,
+ cast_to=VoiceDeployment,
)
def retrieve(
@@ -111,8 +106,8 @@ def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceRetrieveResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Get a single voice deployment
@@ -134,7 +129,7 @@ def retrieve(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceRetrieveResponse,
+ cast_to=VoiceDeployment,
)
def update(
@@ -143,16 +138,16 @@ def update(
*,
worker_id: str,
name: str,
- phone_number: str | NotGiven = NOT_GIVEN,
- voice_id: str | NotGiven = NOT_GIVEN,
- voice_provider: str | NotGiven = NOT_GIVEN,
+ phone_number: str | Omit = omit,
+ voice_id: str | Omit = omit,
+ voice_provider: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceUpdateResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Update a voice deployment
@@ -191,7 +186,7 @@ def update(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceUpdateResponse,
+ cast_to=VoiceDeployment,
)
def list(
@@ -203,7 +198,7 @@ def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> VoiceListResponse:
"""
Get all voice deployments for a worker
@@ -237,7 +232,7 @@ def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a voice deployment
@@ -290,16 +285,16 @@ async def create(
worker_id: str,
*,
name: str,
- phone_number: str | NotGiven = NOT_GIVEN,
- voice_id: str | NotGiven = NOT_GIVEN,
- voice_provider: str | NotGiven = NOT_GIVEN,
+ phone_number: str | Omit = omit,
+ voice_id: str | Omit = omit,
+ voice_provider: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceCreateResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Create a new voice deployment
@@ -336,7 +331,7 @@ async def create(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceCreateResponse,
+ cast_to=VoiceDeployment,
)
async def retrieve(
@@ -349,8 +344,8 @@ async def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceRetrieveResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Get a single voice deployment
@@ -372,7 +367,7 @@ async def retrieve(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceRetrieveResponse,
+ cast_to=VoiceDeployment,
)
async def update(
@@ -381,16 +376,16 @@ async def update(
*,
worker_id: str,
name: str,
- phone_number: str | NotGiven = NOT_GIVEN,
- voice_id: str | NotGiven = NOT_GIVEN,
- voice_provider: str | NotGiven = NOT_GIVEN,
+ phone_number: str | Omit = omit,
+ voice_id: str | Omit = omit,
+ voice_provider: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VoiceUpdateResponse:
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> VoiceDeployment:
"""
Update a voice deployment
@@ -429,7 +424,7 @@ async def update(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=VoiceUpdateResponse,
+ cast_to=VoiceDeployment,
)
async def list(
@@ -441,7 +436,7 @@ async def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> VoiceListResponse:
"""
Get all voice deployments for a worker
@@ -475,7 +470,7 @@ async def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a voice deployment
diff --git a/src/brainbase/resources/workers/flows.py b/src/brainbase/resources/workers/flows.py
index 0d2ef89f..7bb20bc7 100644
--- a/src/brainbase/resources/workers/flows.py
+++ b/src/brainbase/resources/workers/flows.py
@@ -4,11 +4,8 @@
import httpx
-from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -53,13 +50,13 @@ def create(
*,
code: str,
name: str,
- label: str | NotGiven = NOT_GIVEN,
+ label: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowCreateResponse:
"""
Create a new flow
@@ -107,7 +104,7 @@ def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowRetrieveResponse:
"""
Get a single flow
@@ -138,15 +135,15 @@ def update(
flow_id: str,
*,
worker_id: str,
- code: str | NotGiven = NOT_GIVEN,
- label: str | NotGiven = NOT_GIVEN,
- name: str | NotGiven = NOT_GIVEN,
+ code: str | Omit = omit,
+ label: str | Omit = omit,
+ name: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowUpdateResponse:
"""
Update a flow
@@ -195,7 +192,7 @@ def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowListResponse:
"""
Get all flows for a worker
@@ -229,7 +226,7 @@ def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a flow
@@ -283,13 +280,13 @@ async def create(
*,
code: str,
name: str,
- label: str | NotGiven = NOT_GIVEN,
+ label: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowCreateResponse:
"""
Create a new flow
@@ -337,7 +334,7 @@ async def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowRetrieveResponse:
"""
Get a single flow
@@ -368,15 +365,15 @@ async def update(
flow_id: str,
*,
worker_id: str,
- code: str | NotGiven = NOT_GIVEN,
- label: str | NotGiven = NOT_GIVEN,
- name: str | NotGiven = NOT_GIVEN,
+ code: str | Omit = omit,
+ label: str | Omit = omit,
+ name: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowUpdateResponse:
"""
Update a flow
@@ -425,7 +422,7 @@ async def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> FlowListResponse:
"""
Get all flows for a worker
@@ -459,7 +456,7 @@ async def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a flow
diff --git a/src/brainbase/resources/workers/workers.py b/src/brainbase/resources/workers/workers.py
index 9c160fbb..30d1f790 100644
--- a/src/brainbase/resources/workers/workers.py
+++ b/src/brainbase/resources/workers/workers.py
@@ -13,11 +13,8 @@
AsyncFlowsResourceWithStreamingResponse,
)
from ...types import worker_create_params, worker_update_params
-from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -75,13 +72,13 @@ def create(
self,
*,
name: str,
- description: str | NotGiven = NOT_GIVEN,
+ description: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerCreateResponse:
"""
Create a new worker
@@ -123,7 +120,7 @@ def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerRetrieveResponse:
"""
Get a single worker
@@ -151,14 +148,14 @@ def update(
self,
id: str,
*,
- description: str | NotGiven = NOT_GIVEN,
- name: str | NotGiven = NOT_GIVEN,
+ description: str | Omit = omit,
+ name: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerUpdateResponse:
"""
Update a worker
@@ -201,7 +198,7 @@ def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerListResponse:
"""Get all workers for the team"""
return self._get(
@@ -221,7 +218,7 @@ def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a worker
@@ -279,13 +276,13 @@ async def create(
self,
*,
name: str,
- description: str | NotGiven = NOT_GIVEN,
+ description: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerCreateResponse:
"""
Create a new worker
@@ -327,7 +324,7 @@ async def retrieve(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerRetrieveResponse:
"""
Get a single worker
@@ -355,14 +352,14 @@ async def update(
self,
id: str,
*,
- description: str | NotGiven = NOT_GIVEN,
- name: str | NotGiven = NOT_GIVEN,
+ description: str | Omit = omit,
+ name: str | Omit = omit,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerUpdateResponse:
"""
Update a worker
@@ -405,7 +402,7 @@ async def list(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> WorkerListResponse:
"""Get all workers for the team"""
return await self._get(
@@ -425,7 +422,7 @@ async def delete(
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> None:
"""
Delete a worker
diff --git a/src/brainbase/types/workers/deployments/__init__.py b/src/brainbase/types/workers/deployments/__init__.py
index dbdb876a..e5947326 100644
--- a/src/brainbase/types/workers/deployments/__init__.py
+++ b/src/brainbase/types/workers/deployments/__init__.py
@@ -2,9 +2,7 @@
from __future__ import annotations
+from .voice_deployment import VoiceDeployment as VoiceDeployment
from .voice_create_params import VoiceCreateParams as VoiceCreateParams
from .voice_list_response import VoiceListResponse as VoiceListResponse
from .voice_update_params import VoiceUpdateParams as VoiceUpdateParams
-from .voice_create_response import VoiceCreateResponse as VoiceCreateResponse
-from .voice_update_response import VoiceUpdateResponse as VoiceUpdateResponse
-from .voice_retrieve_response import VoiceRetrieveResponse as VoiceRetrieveResponse
diff --git a/src/brainbase/types/workers/deployments/voice_create_response.py b/src/brainbase/types/workers/deployments/voice_deployment.py
similarity index 88%
rename from src/brainbase/types/workers/deployments/voice_create_response.py
rename to src/brainbase/types/workers/deployments/voice_deployment.py
index 2c9a7960..61ea6936 100644
--- a/src/brainbase/types/workers/deployments/voice_create_response.py
+++ b/src/brainbase/types/workers/deployments/voice_deployment.py
@@ -6,10 +6,10 @@
from ...._models import BaseModel
-__all__ = ["VoiceCreateResponse"]
+__all__ = ["VoiceDeployment"]
-class VoiceCreateResponse(BaseModel):
+class VoiceDeployment(BaseModel):
id: str
delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None)
diff --git a/src/brainbase/types/workers/deployments/voice_list_response.py b/src/brainbase/types/workers/deployments/voice_list_response.py
index d54f167d..d714d5dc 100644
--- a/src/brainbase/types/workers/deployments/voice_list_response.py
+++ b/src/brainbase/types/workers/deployments/voice_list_response.py
@@ -1,25 +1,10 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-from typing import List, Optional
+from typing import List
from typing_extensions import TypeAlias
-from pydantic import Field as FieldInfo
+from .voice_deployment import VoiceDeployment
-from ...._models import BaseModel
+__all__ = ["VoiceListResponse"]
-__all__ = ["VoiceListResponse", "VoiceListResponseItem"]
-
-
-class VoiceListResponseItem(BaseModel):
- id: str
-
- delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None)
-
- phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None)
-
- voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None)
-
- voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None)
-
-
-VoiceListResponse: TypeAlias = List[VoiceListResponseItem]
+VoiceListResponse: TypeAlias = List[VoiceDeployment]
diff --git a/src/brainbase/types/workers/deployments/voice_retrieve_response.py b/src/brainbase/types/workers/deployments/voice_retrieve_response.py
deleted file mode 100644
index 15fc4abe..00000000
--- a/src/brainbase/types/workers/deployments/voice_retrieve_response.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing import Optional
-
-from pydantic import Field as FieldInfo
-
-from ...._models import BaseModel
-
-__all__ = ["VoiceRetrieveResponse"]
-
-
-class VoiceRetrieveResponse(BaseModel):
- id: str
-
- delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None)
-
- phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None)
-
- voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None)
-
- voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None)
diff --git a/src/brainbase/types/workers/deployments/voice_update_response.py b/src/brainbase/types/workers/deployments/voice_update_response.py
deleted file mode 100644
index 67d8460f..00000000
--- a/src/brainbase/types/workers/deployments/voice_update_response.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing import Optional
-
-from pydantic import Field as FieldInfo
-
-from ...._models import BaseModel
-
-__all__ = ["VoiceUpdateResponse"]
-
-
-class VoiceUpdateResponse(BaseModel):
- id: str
-
- delegate_aux_deployments_id: Optional[str] = FieldInfo(alias="delegate_aux_deploymentsId", default=None)
-
- phone_number: Optional[str] = FieldInfo(alias="phoneNumber", default=None)
-
- voice_id: Optional[str] = FieldInfo(alias="voiceId", default=None)
-
- voice_provider: Optional[str] = FieldInfo(alias="voiceProvider", default=None)
diff --git a/tests/api_resources/test_workers.py b/tests/api_resources/test_workers.py
index 28d15c32..2ae82aa6 100644
--- a/tests/api_resources/test_workers.py
+++ b/tests/api_resources/test_workers.py
@@ -22,7 +22,7 @@
class TestWorkers:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create(self, client: Brainbase) -> None:
worker = client.workers.create(
@@ -30,7 +30,7 @@ def test_method_create(self, client: Brainbase) -> None:
)
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create_with_all_params(self, client: Brainbase) -> None:
worker = client.workers.create(
@@ -39,7 +39,7 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None:
)
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_create(self, client: Brainbase) -> None:
response = client.workers.with_raw_response.create(
@@ -51,7 +51,7 @@ def test_raw_response_create(self, client: Brainbase) -> None:
worker = response.parse()
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_create(self, client: Brainbase) -> None:
with client.workers.with_streaming_response.create(
@@ -65,7 +65,7 @@ def test_streaming_response_create(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_retrieve(self, client: Brainbase) -> None:
worker = client.workers.retrieve(
@@ -73,7 +73,7 @@ def test_method_retrieve(self, client: Brainbase) -> None:
)
assert_matches_type(WorkerRetrieveResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_retrieve(self, client: Brainbase) -> None:
response = client.workers.with_raw_response.retrieve(
@@ -85,7 +85,7 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None:
worker = response.parse()
assert_matches_type(WorkerRetrieveResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_retrieve(self, client: Brainbase) -> None:
with client.workers.with_streaming_response.retrieve(
@@ -99,7 +99,7 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_retrieve(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
@@ -107,7 +107,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update(self, client: Brainbase) -> None:
worker = client.workers.update(
@@ -115,7 +115,7 @@ def test_method_update(self, client: Brainbase) -> None:
)
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update_with_all_params(self, client: Brainbase) -> None:
worker = client.workers.update(
@@ -125,7 +125,7 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None:
)
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_update(self, client: Brainbase) -> None:
response = client.workers.with_raw_response.update(
@@ -137,7 +137,7 @@ def test_raw_response_update(self, client: Brainbase) -> None:
worker = response.parse()
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_update(self, client: Brainbase) -> None:
with client.workers.with_streaming_response.update(
@@ -151,7 +151,7 @@ def test_streaming_response_update(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_update(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
@@ -159,13 +159,13 @@ def test_path_params_update(self, client: Brainbase) -> None:
id="",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_list(self, client: Brainbase) -> None:
worker = client.workers.list()
assert_matches_type(WorkerListResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_list(self, client: Brainbase) -> None:
response = client.workers.with_raw_response.list()
@@ -175,7 +175,7 @@ def test_raw_response_list(self, client: Brainbase) -> None:
worker = response.parse()
assert_matches_type(WorkerListResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_list(self, client: Brainbase) -> None:
with client.workers.with_streaming_response.list() as response:
@@ -187,7 +187,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_delete(self, client: Brainbase) -> None:
worker = client.workers.delete(
@@ -195,7 +195,7 @@ def test_method_delete(self, client: Brainbase) -> None:
)
assert worker is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_delete(self, client: Brainbase) -> None:
response = client.workers.with_raw_response.delete(
@@ -207,7 +207,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None:
worker = response.parse()
assert worker is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_delete(self, client: Brainbase) -> None:
with client.workers.with_streaming_response.delete(
@@ -221,7 +221,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_delete(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
@@ -231,9 +231,11 @@ def test_path_params_delete(self, client: Brainbase) -> None:
class TestAsyncWorkers:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+ parametrize = pytest.mark.parametrize(
+ "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+ )
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.create(
@@ -241,7 +243,7 @@ async def test_method_create(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.create(
@@ -250,7 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase)
)
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.with_raw_response.create(
@@ -262,7 +264,7 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
worker = await response.parse()
assert_matches_type(WorkerCreateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.with_streaming_response.create(
@@ -276,7 +278,7 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.retrieve(
@@ -284,7 +286,7 @@ async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(WorkerRetrieveResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.with_raw_response.retrieve(
@@ -296,7 +298,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None
worker = await response.parse()
assert_matches_type(WorkerRetrieveResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.with_streaming_response.retrieve(
@@ -310,7 +312,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
@@ -318,7 +320,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.update(
@@ -326,7 +328,7 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.update(
@@ -336,7 +338,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase)
)
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.with_raw_response.update(
@@ -348,7 +350,7 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
worker = await response.parse()
assert_matches_type(WorkerUpdateResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.with_streaming_response.update(
@@ -362,7 +364,7 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
@@ -370,13 +372,13 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
id="",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_list(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.list()
assert_matches_type(WorkerListResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.with_raw_response.list()
@@ -386,7 +388,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
worker = await response.parse()
assert_matches_type(WorkerListResponse, worker, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.with_streaming_response.list() as response:
@@ -398,7 +400,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
worker = await async_client.workers.delete(
@@ -406,7 +408,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
)
assert worker is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.with_raw_response.delete(
@@ -418,7 +420,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
worker = await response.parse()
assert worker is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.with_streaming_response.delete(
@@ -432,7 +434,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
diff --git a/tests/api_resources/workers/deployments/test_voice.py b/tests/api_resources/workers/deployments/test_voice.py
index 7bdabc59..f8b0b122 100644
--- a/tests/api_resources/workers/deployments/test_voice.py
+++ b/tests/api_resources/workers/deployments/test_voice.py
@@ -10,10 +10,8 @@
from brainbase import Brainbase, AsyncBrainbase
from tests.utils import assert_matches_type
from brainbase.types.workers.deployments import (
+ VoiceDeployment,
VoiceListResponse,
- VoiceCreateResponse,
- VoiceUpdateResponse,
- VoiceRetrieveResponse,
)
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -22,16 +20,16 @@
class TestVoice:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.create(
worker_id="workerId",
name="name",
)
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create_with_all_params(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.create(
@@ -41,9 +39,9 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None:
voice_id="voiceId",
voice_provider="voiceProvider",
)
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_create(self, client: Brainbase) -> None:
response = client.workers.deployments.voice.with_raw_response.create(
@@ -54,9 +52,9 @@ def test_raw_response_create(self, client: Brainbase) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_create(self, client: Brainbase) -> None:
with client.workers.deployments.voice.with_streaming_response.create(
@@ -67,11 +65,11 @@ def test_streaming_response_create(self, client: Brainbase) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_create(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -80,16 +78,16 @@ def test_path_params_create(self, client: Brainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_retrieve(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.retrieve(
deployment_id="deploymentId",
worker_id="workerId",
)
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_retrieve(self, client: Brainbase) -> None:
response = client.workers.deployments.voice.with_raw_response.retrieve(
@@ -100,9 +98,9 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_retrieve(self, client: Brainbase) -> None:
with client.workers.deployments.voice.with_streaming_response.retrieve(
@@ -113,11 +111,11 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_retrieve(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -132,7 +130,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.update(
@@ -140,9 +138,9 @@ def test_method_update(self, client: Brainbase) -> None:
worker_id="workerId",
name="name",
)
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update_with_all_params(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.update(
@@ -153,9 +151,9 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None:
voice_id="voiceId",
voice_provider="voiceProvider",
)
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_update(self, client: Brainbase) -> None:
response = client.workers.deployments.voice.with_raw_response.update(
@@ -167,9 +165,9 @@ def test_raw_response_update(self, client: Brainbase) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_update(self, client: Brainbase) -> None:
with client.workers.deployments.voice.with_streaming_response.update(
@@ -181,11 +179,11 @@ def test_streaming_response_update(self, client: Brainbase) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = response.parse()
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_update(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -202,7 +200,7 @@ def test_path_params_update(self, client: Brainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_list(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.list(
@@ -210,7 +208,7 @@ def test_method_list(self, client: Brainbase) -> None:
)
assert_matches_type(VoiceListResponse, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_list(self, client: Brainbase) -> None:
response = client.workers.deployments.voice.with_raw_response.list(
@@ -222,7 +220,7 @@ def test_raw_response_list(self, client: Brainbase) -> None:
voice = response.parse()
assert_matches_type(VoiceListResponse, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_list(self, client: Brainbase) -> None:
with client.workers.deployments.voice.with_streaming_response.list(
@@ -236,7 +234,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_list(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -244,7 +242,7 @@ def test_path_params_list(self, client: Brainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_delete(self, client: Brainbase) -> None:
voice = client.workers.deployments.voice.delete(
@@ -253,7 +251,7 @@ def test_method_delete(self, client: Brainbase) -> None:
)
assert voice is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_delete(self, client: Brainbase) -> None:
response = client.workers.deployments.voice.with_raw_response.delete(
@@ -266,7 +264,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None:
voice = response.parse()
assert voice is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_delete(self, client: Brainbase) -> None:
with client.workers.deployments.voice.with_streaming_response.delete(
@@ -281,7 +279,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_delete(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -298,18 +296,20 @@ def test_path_params_delete(self, client: Brainbase) -> None:
class TestAsyncVoice:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+ parametrize = pytest.mark.parametrize(
+ "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+ )
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.create(
worker_id="workerId",
name="name",
)
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.create(
@@ -319,9 +319,9 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase)
voice_id="voiceId",
voice_provider="voiceProvider",
)
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.deployments.voice.with_raw_response.create(
@@ -332,9 +332,9 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.deployments.voice.with_streaming_response.create(
@@ -345,11 +345,11 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) ->
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceCreateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_create(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -358,16 +358,16 @@ async def test_path_params_create(self, async_client: AsyncBrainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.retrieve(
deployment_id="deploymentId",
worker_id="workerId",
)
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.deployments.voice.with_raw_response.retrieve(
@@ -378,9 +378,9 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.deployments.voice.with_streaming_response.retrieve(
@@ -391,11 +391,11 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceRetrieveResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -410,7 +410,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.update(
@@ -418,9 +418,9 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None:
worker_id="workerId",
name="name",
)
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.update(
@@ -431,9 +431,9 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase)
voice_id="voiceId",
voice_provider="voiceProvider",
)
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.deployments.voice.with_raw_response.update(
@@ -445,9 +445,9 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.deployments.voice.with_streaming_response.update(
@@ -459,11 +459,11 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) ->
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
voice = await response.parse()
- assert_matches_type(VoiceUpdateResponse, voice, path=["response"])
+ assert_matches_type(VoiceDeployment, voice, path=["response"])
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -480,7 +480,7 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_list(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.list(
@@ -488,7 +488,7 @@ async def test_method_list(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(VoiceListResponse, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.deployments.voice.with_raw_response.list(
@@ -500,7 +500,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
voice = await response.parse()
assert_matches_type(VoiceListResponse, voice, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.deployments.voice.with_streaming_response.list(
@@ -514,7 +514,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_list(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -522,7 +522,7 @@ async def test_path_params_list(self, async_client: AsyncBrainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
voice = await async_client.workers.deployments.voice.delete(
@@ -531,7 +531,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
)
assert voice is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.deployments.voice.with_raw_response.delete(
@@ -544,7 +544,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
voice = await response.parse()
assert voice is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.deployments.voice.with_streaming_response.delete(
@@ -559,7 +559,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
diff --git a/tests/api_resources/workers/test_flows.py b/tests/api_resources/workers/test_flows.py
index e4d645ca..93cc926a 100644
--- a/tests/api_resources/workers/test_flows.py
+++ b/tests/api_resources/workers/test_flows.py
@@ -22,7 +22,7 @@
class TestFlows:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create(self, client: Brainbase) -> None:
flow = client.workers.flows.create(
@@ -32,7 +32,7 @@ def test_method_create(self, client: Brainbase) -> None:
)
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_create_with_all_params(self, client: Brainbase) -> None:
flow = client.workers.flows.create(
@@ -43,7 +43,7 @@ def test_method_create_with_all_params(self, client: Brainbase) -> None:
)
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_create(self, client: Brainbase) -> None:
response = client.workers.flows.with_raw_response.create(
@@ -57,7 +57,7 @@ def test_raw_response_create(self, client: Brainbase) -> None:
flow = response.parse()
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_create(self, client: Brainbase) -> None:
with client.workers.flows.with_streaming_response.create(
@@ -73,7 +73,7 @@ def test_streaming_response_create(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_create(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -83,7 +83,7 @@ def test_path_params_create(self, client: Brainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_retrieve(self, client: Brainbase) -> None:
flow = client.workers.flows.retrieve(
@@ -92,7 +92,7 @@ def test_method_retrieve(self, client: Brainbase) -> None:
)
assert_matches_type(FlowRetrieveResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_retrieve(self, client: Brainbase) -> None:
response = client.workers.flows.with_raw_response.retrieve(
@@ -105,7 +105,7 @@ def test_raw_response_retrieve(self, client: Brainbase) -> None:
flow = response.parse()
assert_matches_type(FlowRetrieveResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_retrieve(self, client: Brainbase) -> None:
with client.workers.flows.with_streaming_response.retrieve(
@@ -120,7 +120,7 @@ def test_streaming_response_retrieve(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_retrieve(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -135,7 +135,7 @@ def test_path_params_retrieve(self, client: Brainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update(self, client: Brainbase) -> None:
flow = client.workers.flows.update(
@@ -144,7 +144,7 @@ def test_method_update(self, client: Brainbase) -> None:
)
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_update_with_all_params(self, client: Brainbase) -> None:
flow = client.workers.flows.update(
@@ -156,7 +156,7 @@ def test_method_update_with_all_params(self, client: Brainbase) -> None:
)
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_update(self, client: Brainbase) -> None:
response = client.workers.flows.with_raw_response.update(
@@ -169,7 +169,7 @@ def test_raw_response_update(self, client: Brainbase) -> None:
flow = response.parse()
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_update(self, client: Brainbase) -> None:
with client.workers.flows.with_streaming_response.update(
@@ -184,7 +184,7 @@ def test_streaming_response_update(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_update(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -199,7 +199,7 @@ def test_path_params_update(self, client: Brainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_list(self, client: Brainbase) -> None:
flow = client.workers.flows.list(
@@ -207,7 +207,7 @@ def test_method_list(self, client: Brainbase) -> None:
)
assert_matches_type(FlowListResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_list(self, client: Brainbase) -> None:
response = client.workers.flows.with_raw_response.list(
@@ -219,7 +219,7 @@ def test_raw_response_list(self, client: Brainbase) -> None:
flow = response.parse()
assert_matches_type(FlowListResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_list(self, client: Brainbase) -> None:
with client.workers.flows.with_streaming_response.list(
@@ -233,7 +233,7 @@ def test_streaming_response_list(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_list(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -241,7 +241,7 @@ def test_path_params_list(self, client: Brainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_delete(self, client: Brainbase) -> None:
flow = client.workers.flows.delete(
@@ -250,7 +250,7 @@ def test_method_delete(self, client: Brainbase) -> None:
)
assert flow is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_delete(self, client: Brainbase) -> None:
response = client.workers.flows.with_raw_response.delete(
@@ -263,7 +263,7 @@ def test_raw_response_delete(self, client: Brainbase) -> None:
flow = response.parse()
assert flow is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_delete(self, client: Brainbase) -> None:
with client.workers.flows.with_streaming_response.delete(
@@ -278,7 +278,7 @@ def test_streaming_response_delete(self, client: Brainbase) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_path_params_delete(self, client: Brainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -295,9 +295,11 @@ def test_path_params_delete(self, client: Brainbase) -> None:
class TestAsyncFlows:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+ parametrize = pytest.mark.parametrize(
+ "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+ )
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.create(
@@ -307,7 +309,7 @@ async def test_method_create(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_create_with_all_params(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.create(
@@ -318,7 +320,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncBrainbase)
)
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.flows.with_raw_response.create(
@@ -332,7 +334,7 @@ async def test_raw_response_create(self, async_client: AsyncBrainbase) -> None:
flow = await response.parse()
assert_matches_type(FlowCreateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_create(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.flows.with_streaming_response.create(
@@ -348,7 +350,7 @@ async def test_streaming_response_create(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_create(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -358,7 +360,7 @@ async def test_path_params_create(self, async_client: AsyncBrainbase) -> None:
name="name",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.retrieve(
@@ -367,7 +369,7 @@ async def test_method_retrieve(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(FlowRetrieveResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.flows.with_raw_response.retrieve(
@@ -380,7 +382,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncBrainbase) -> None
flow = await response.parse()
assert_matches_type(FlowRetrieveResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.flows.with_streaming_response.retrieve(
@@ -395,7 +397,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncBrainbase) -
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -410,7 +412,7 @@ async def test_path_params_retrieve(self, async_client: AsyncBrainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.update(
@@ -419,7 +421,7 @@ async def test_method_update(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_update_with_all_params(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.update(
@@ -431,7 +433,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncBrainbase)
)
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.flows.with_raw_response.update(
@@ -444,7 +446,7 @@ async def test_raw_response_update(self, async_client: AsyncBrainbase) -> None:
flow = await response.parse()
assert_matches_type(FlowUpdateResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_update(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.flows.with_streaming_response.update(
@@ -459,7 +461,7 @@ async def test_streaming_response_update(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -474,7 +476,7 @@ async def test_path_params_update(self, async_client: AsyncBrainbase) -> None:
worker_id="workerId",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_list(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.list(
@@ -482,7 +484,7 @@ async def test_method_list(self, async_client: AsyncBrainbase) -> None:
)
assert_matches_type(FlowListResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.flows.with_raw_response.list(
@@ -494,7 +496,7 @@ async def test_raw_response_list(self, async_client: AsyncBrainbase) -> None:
flow = await response.parse()
assert_matches_type(FlowListResponse, flow, path=["response"])
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.flows.with_streaming_response.list(
@@ -508,7 +510,7 @@ async def test_streaming_response_list(self, async_client: AsyncBrainbase) -> No
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_list(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
@@ -516,7 +518,7 @@ async def test_path_params_list(self, async_client: AsyncBrainbase) -> None:
"",
)
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
flow = await async_client.workers.flows.delete(
@@ -525,7 +527,7 @@ async def test_method_delete(self, async_client: AsyncBrainbase) -> None:
)
assert flow is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
response = await async_client.workers.flows.with_raw_response.delete(
@@ -538,7 +540,7 @@ async def test_raw_response_delete(self, async_client: AsyncBrainbase) -> None:
flow = await response.parse()
assert flow is None
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_delete(self, async_client: AsyncBrainbase) -> None:
async with async_client.workers.flows.with_streaming_response.delete(
@@ -553,7 +555,7 @@ async def test_streaming_response_delete(self, async_client: AsyncBrainbase) ->
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
+ @pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_path_params_delete(self, async_client: AsyncBrainbase) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `worker_id` but received ''"):
diff --git a/tests/conftest.py b/tests/conftest.py
index 8e89d982..7a6d4de1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,16 +1,20 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
from __future__ import annotations
import os
import logging
from typing import TYPE_CHECKING, Iterator, AsyncIterator
+import httpx
import pytest
from pytest_asyncio import is_async_test
-from brainbase import Brainbase, AsyncBrainbase
+from brainbase import Brainbase, AsyncBrainbase, DefaultAioHttpClient
+from brainbase._utils import is_dict
if TYPE_CHECKING:
- from _pytest.fixtures import FixtureRequest
+ from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage]
pytest.register_assert_rewrite("tests.utils")
@@ -25,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
for async_test in pytest_asyncio_tests:
async_test.add_marker(session_scope_marker, append=False)
+ # We skip tests that use both the aiohttp client and respx_mock as respx_mock
+ # doesn't support custom transports.
+ for item in items:
+ if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames:
+ continue
+
+ if not hasattr(item, "callspec"):
+ continue
+
+ async_client_param = item.callspec.params.get("async_client")
+ if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp":
+ item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock"))
+
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -43,9 +60,25 @@ def client(request: FixtureRequest) -> Iterator[Brainbase]:
@pytest.fixture(scope="session")
async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncBrainbase]:
- strict = getattr(request, "param", True)
- if not isinstance(strict, bool):
- raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
-
- async with AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client:
+ param = getattr(request, "param", True)
+
+ # defaults
+ strict = True
+ http_client: None | httpx.AsyncClient = None
+
+ if isinstance(param, bool):
+ strict = param
+ elif is_dict(param):
+ strict = param.get("strict", True)
+ assert isinstance(strict, bool)
+
+ http_client_type = param.get("http_client", "httpx")
+ if http_client_type == "aiohttp":
+ http_client = DefaultAioHttpClient()
+ else:
+ raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict")
+
+ async with AsyncBrainbase(
+ base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client
+ ) as client:
yield client
diff --git a/tests/test_client.py b/tests/test_client.py
index c3570edb..6f35f9d0 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -6,13 +6,10 @@
import os
import sys
import json
-import time
import asyncio
import inspect
-import subprocess
import tracemalloc
from typing import Any, Union, cast
-from textwrap import dedent
from unittest import mock
from typing_extensions import Literal
@@ -23,13 +20,17 @@
from brainbase import Brainbase, AsyncBrainbase, APIResponseValidationError
from brainbase._types import Omit
+from brainbase._utils import asyncify
from brainbase._models import BaseModel, FinalRequestOptions
-from brainbase._constants import RAW_RESPONSE_HEADER
from brainbase._exceptions import APIStatusError, BrainbaseError, APITimeoutError, APIResponseValidationError
from brainbase._base_client import (
DEFAULT_TIMEOUT,
HTTPX_DEFAULT_TIMEOUT,
BaseClient,
+ OtherPlatform,
+ DefaultHttpxClient,
+ DefaultAsyncHttpxClient,
+ get_platform,
make_request_options,
)
@@ -58,51 +59,49 @@ def _get_open_connections(client: Brainbase | AsyncBrainbase) -> int:
class TestBrainbase:
- client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- def test_raw_response(self, respx_mock: MockRouter) -> None:
+ def test_raw_response(self, respx_mock: MockRouter, client: Brainbase) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Brainbase) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, client: Brainbase) -> None:
+ copied = client.copy()
+ assert id(copied) != id(client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, client: Brainbase) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(client.timeout, httpx.Timeout)
+ copied = client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(client.timeout, httpx.Timeout)
def test_copy_default_headers(self) -> None:
client = Brainbase(
@@ -137,6 +136,7 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ client.close()
def test_copy_default_query(self) -> None:
client = Brainbase(
@@ -174,13 +174,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ client.close()
+
+ def test_copy_signature(self, client: Brainbase) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -190,12 +192,13 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"
- def test_copy_build_request(self) -> None:
+ @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
+ def test_copy_build_request(self, client: Brainbase) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -252,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ def test_request_timeout(self, client: Brainbase) -> None:
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
- FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
- )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(100.0)
@@ -272,6 +273,8 @@ def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ client.close()
+
def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
with httpx.Client(timeout=None) as http_client:
@@ -283,6 +286,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ client.close()
+
# no timeout given to the httpx client should not use the httpx default
with httpx.Client() as http_client:
client = Brainbase(
@@ -293,6 +298,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ client.close()
+
# explicitly passing the default timeout currently results in it being ignored
with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = Brainbase(
@@ -303,6 +310,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ client.close()
+
async def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
async with httpx.AsyncClient() as http_client:
@@ -314,14 +323,14 @@ async def test_invalid_http_client(self) -> None:
)
def test_default_headers_option(self) -> None:
- client = Brainbase(
+ test_client = Brainbase(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = Brainbase(
+ test_client2 = Brainbase(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -330,10 +339,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ test_client.close()
+ test_client2.close()
+
def test_validate_headers(self) -> None:
client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -362,8 +374,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ client.close()
+
+ def test_request_extra_json(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -374,7 +388,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -385,7 +399,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -396,8 +410,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -407,7 +421,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -418,8 +432,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -432,7 +446,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -446,7 +460,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -462,7 +476,7 @@ def test_request_extra_query(self) -> None:
def test_multipart_repeating_array(self, client: Brainbase) -> None:
request = client._build_request(
FinalRequestOptions.construct(
- method="get",
+ method="post",
url="/foo",
headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
json_data={"array": ["foo", "bar"]},
@@ -489,7 +503,7 @@ def test_multipart_repeating_array(self, client: Brainbase) -> None:
]
@pytest.mark.respx(base_url=base_url)
- def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ def test_basic_union_response(self, respx_mock: MockRouter, client: Brainbase) -> None:
class Model1(BaseModel):
name: str
@@ -498,12 +512,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ def test_union_response_different_types(self, respx_mock: MockRouter, client: Brainbase) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -514,18 +528,18 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Brainbase) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -541,7 +555,7 @@ class Model(BaseModel):
)
)
- response = self.client.get("/foo", cast_to=Model)
+ response = client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
@@ -553,6 +567,8 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
+ client.close()
+
def test_base_url_env(self) -> None:
with update_env(BRAINBASE_BASE_URL="http://localhost:5000/from/env"):
client = Brainbase(api_key=api_key, _strict_response_validation=True)
@@ -580,6 +596,7 @@ def test_base_url_trailing_slash(self, client: Brainbase) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -603,6 +620,7 @@ def test_base_url_no_trailing_slash(self, client: Brainbase) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -626,35 +644,36 @@ def test_absolute_request_url(self, client: Brainbase) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ client.close()
def test_copied_client_does_not_close_http(self) -> None:
- client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
- assert not client.is_closed()
+ assert not test_client.is_closed()
def test_client_context_manager(self) -> None:
- client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- with client as c2:
- assert c2 is client
+ test_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ def test_client_response_validation_error(self, respx_mock: MockRouter, client: Brainbase) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- self.client.get("/foo", cast_to=Model)
+ client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -674,11 +693,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
strict_client.get("/foo", cast_to=Model)
- client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = client.get("/foo", cast_to=Model)
+ response = non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ strict_client.close()
+ non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -701,9 +723,9 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = Brainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, client: Brainbase
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
@@ -711,27 +733,22 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Brainbase) -> None:
respx_mock.get("/api/workers").mock(side_effect=httpx.TimeoutException("Test timeout error"))
with pytest.raises(APITimeoutError):
- self.client.get(
- "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
- )
+ client.workers.with_streaming_response.list().__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Brainbase) -> None:
respx_mock.get("/api/workers").mock(return_value=httpx.Response(500))
with pytest.raises(APIStatusError):
- self.client.get(
- "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
- )
-
- assert _get_open_connections(self.client) == 0
+ client.workers.with_streaming_response.list().__enter__()
+ assert _get_open_connections(client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -810,57 +827,100 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
+ def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+ # Test that the proxy environment variables are set correctly
+ monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+ client = DefaultHttpxClient()
+
+ mounts = tuple(client._mounts.items())
+ assert len(mounts) == 1
+ assert mounts[0][0].pattern == "https://"
+
+ @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+ def test_default_client_creation(self) -> None:
+ # Ensure that the client can be initialized without any exceptions
+ DefaultHttpxClient(
+ verify=True,
+ cert=None,
+ trust_env=True,
+ http1=True,
+ http2=False,
+ limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+ )
-class TestAsyncBrainbase:
- client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ @pytest.mark.respx(base_url=base_url)
+ def test_follow_redirects(self, respx_mock: MockRouter, client: Brainbase) -> None:
+ # Test that the default follow_redirects=True allows following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+ respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+ response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Brainbase) -> None:
+ # Test that follow_redirects=False prevents following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+ with pytest.raises(APIStatusError) as exc_info:
+ client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
+
+ assert exc_info.value.response.status_code == 302
+ assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+
+
+class TestAsyncBrainbase:
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, async_client: AsyncBrainbase) -> None:
+ copied = async_client.copy()
+ assert id(copied) != id(async_client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = async_client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert async_client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, async_client: AsyncBrainbase) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = async_client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert async_client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(async_client.timeout, httpx.Timeout)
+ copied = async_client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(async_client.timeout, httpx.Timeout)
- def test_copy_default_headers(self) -> None:
+ async def test_copy_default_headers(self) -> None:
client = AsyncBrainbase(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
@@ -893,8 +953,9 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ await client.close()
- def test_copy_default_query(self) -> None:
+ async def test_copy_default_query(self) -> None:
client = AsyncBrainbase(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
)
@@ -930,13 +991,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ await client.close()
+
+ def test_copy_signature(self, async_client: AsyncBrainbase) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ async_client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(async_client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -946,12 +1009,13 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"
- def test_copy_build_request(self) -> None:
+ @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
+ def test_copy_build_request(self, async_client: AsyncBrainbase) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = async_client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -1008,12 +1072,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- async def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ async def test_request_timeout(self, async_client: AsyncBrainbase) -> None:
+ request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
+ request = async_client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -1028,6 +1092,8 @@ async def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ await client.close()
+
async def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
async with httpx.AsyncClient(timeout=None) as http_client:
@@ -1039,6 +1105,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ await client.close()
+
# no timeout given to the httpx client should not use the httpx default
async with httpx.AsyncClient() as http_client:
client = AsyncBrainbase(
@@ -1049,6 +1117,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ await client.close()
+
# explicitly passing the default timeout currently results in it being ignored
async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = AsyncBrainbase(
@@ -1059,6 +1129,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ await client.close()
+
def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
with httpx.Client() as http_client:
@@ -1069,15 +1141,15 @@ def test_invalid_http_client(self) -> None:
http_client=cast(Any, http_client),
)
- def test_default_headers_option(self) -> None:
- client = AsyncBrainbase(
+ async def test_default_headers_option(self) -> None:
+ test_client = AsyncBrainbase(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = AsyncBrainbase(
+ test_client2 = AsyncBrainbase(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -1086,10 +1158,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ await test_client.close()
+ await test_client2.close()
+
def test_validate_headers(self) -> None:
client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -1100,7 +1175,7 @@ def test_validate_headers(self) -> None:
client2 = AsyncBrainbase(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
- def test_default_query_option(self) -> None:
+ async def test_default_query_option(self) -> None:
client = AsyncBrainbase(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
)
@@ -1118,8 +1193,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ await client.close()
+
+ def test_request_extra_json(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1130,7 +1207,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1141,7 +1218,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1152,8 +1229,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1163,7 +1240,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1174,8 +1251,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Brainbase) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1188,7 +1265,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1202,7 +1279,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1218,7 +1295,7 @@ def test_request_extra_query(self) -> None:
def test_multipart_repeating_array(self, async_client: AsyncBrainbase) -> None:
request = async_client._build_request(
FinalRequestOptions.construct(
- method="get",
+ method="post",
url="/foo",
headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
json_data={"array": ["foo", "bar"]},
@@ -1245,7 +1322,7 @@ def test_multipart_repeating_array(self, async_client: AsyncBrainbase) -> None:
]
@pytest.mark.respx(base_url=base_url)
- async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
class Model1(BaseModel):
name: str
@@ -1254,12 +1331,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -1270,18 +1347,20 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ async def test_non_application_json_content_type_for_json_data(
+ self, respx_mock: MockRouter, async_client: AsyncBrainbase
+ ) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -1297,11 +1376,11 @@ class Model(BaseModel):
)
)
- response = await self.client.get("/foo", cast_to=Model)
+ response = await async_client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
- def test_base_url_setter(self) -> None:
+ async def test_base_url_setter(self) -> None:
client = AsyncBrainbase(
base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
)
@@ -1311,7 +1390,9 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
- def test_base_url_env(self) -> None:
+ await client.close()
+
+ async def test_base_url_env(self) -> None:
with update_env(BRAINBASE_BASE_URL="http://localhost:5000/from/env"):
client = AsyncBrainbase(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"
@@ -1331,7 +1412,7 @@ def test_base_url_env(self) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None:
+ async def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1340,6 +1421,7 @@ def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1356,7 +1438,7 @@ def test_base_url_trailing_slash(self, client: AsyncBrainbase) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None:
+ async def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1365,6 +1447,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1381,7 +1464,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncBrainbase) -> None:
],
ids=["standard", "custom http client"],
)
- def test_absolute_request_url(self, client: AsyncBrainbase) -> None:
+ async def test_absolute_request_url(self, client: AsyncBrainbase) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1390,37 +1473,37 @@ def test_absolute_request_url(self, client: AsyncBrainbase) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ await client.close()
async def test_copied_client_does_not_close_http(self) -> None:
- client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
await asyncio.sleep(0.2)
- assert not client.is_closed()
+ assert not test_client.is_closed()
async def test_client_context_manager(self) -> None:
- client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- async with client as c2:
- assert c2 is client
+ test_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ async with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- await self.client.get("/foo", cast_to=Model)
+ await async_client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -1431,7 +1514,6 @@ async def test_client_max_retries_validation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str
@@ -1443,11 +1525,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
await strict_client.get("/foo", cast_to=Model)
- client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = await client.get("/foo", cast_to=Model)
+ response = await non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ await strict_client.close()
+ await non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -1470,43 +1555,40 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- @pytest.mark.asyncio
- async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = AsyncBrainbase(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ async def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncBrainbase
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
- calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+ async def test_retrying_timeout_errors_doesnt_leak(
+ self, respx_mock: MockRouter, async_client: AsyncBrainbase
+ ) -> None:
respx_mock.get("/api/workers").mock(side_effect=httpx.TimeoutException("Test timeout error"))
with pytest.raises(APITimeoutError):
- await self.client.get(
- "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
- )
+ await async_client.workers.with_streaming_response.list().__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+ async def test_retrying_status_errors_doesnt_leak(
+ self, respx_mock: MockRouter, async_client: AsyncBrainbase
+ ) -> None:
respx_mock.get("/api/workers").mock(return_value=httpx.Response(500))
with pytest.raises(APIStatusError):
- await self.client.get(
- "/api/workers", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
- )
-
- assert _get_open_connections(self.client) == 0
+ await async_client.workers.with_streaming_response.list().__aenter__()
+ assert _get_open_connections(async_client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
self,
@@ -1538,7 +1620,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_omit_retry_count_header(
self, async_client: AsyncBrainbase, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1562,7 +1643,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("brainbase._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_overwrite_retry_count_header(
self, async_client: AsyncBrainbase, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1583,47 +1663,55 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
- def test_get_platform(self) -> None:
- # A previous implementation of asyncify could leave threads unterminated when
- # used with nest_asyncio.
- #
- # Since nest_asyncio.apply() is global and cannot be un-applied, this
- # test is run in a separate process to avoid affecting other tests.
- test_code = dedent("""
- import asyncio
- import nest_asyncio
- import threading
-
- from brainbase._utils import asyncify
- from brainbase._base_client import get_platform
-
- async def test_main() -> None:
- result = await asyncify(get_platform)()
- print(result)
- for thread in threading.enumerate():
- print(thread.name)
-
- nest_asyncio.apply()
- asyncio.run(test_main())
- """)
- with subprocess.Popen(
- [sys.executable, "-c", test_code],
- text=True,
- ) as process:
- timeout = 10 # seconds
-
- start_time = time.monotonic()
- while True:
- return_code = process.poll()
- if return_code is not None:
- if return_code != 0:
- raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
-
- # success
- break
-
- if time.monotonic() - start_time > timeout:
- process.kill()
- raise AssertionError("calling get_platform using asyncify resulted in a hung process")
-
- time.sleep(0.1)
+ async def test_get_platform(self) -> None:
+ platform = await asyncify(get_platform)()
+ assert isinstance(platform, (str, OtherPlatform))
+
+ async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+ # Test that the proxy environment variables are set correctly
+ monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+ client = DefaultAsyncHttpxClient()
+
+ mounts = tuple(client._mounts.items())
+ assert len(mounts) == 1
+ assert mounts[0][0].pattern == "https://"
+
+ @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+ async def test_default_client_creation(self) -> None:
+ # Ensure that the client can be initialized without any exceptions
+ DefaultAsyncHttpxClient(
+ verify=True,
+ cert=None,
+ trust_env=True,
+ http1=True,
+ http2=False,
+ limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+ )
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
+ # Test that the default follow_redirects=True allows following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+ respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+ response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncBrainbase) -> None:
+ # Test that follow_redirects=False prevents following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+
+ with pytest.raises(APIStatusError) as exc_info:
+ await async_client.post(
+ "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
+ )
+
+ assert exc_info.value.response.status_code == 302
+ assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
diff --git a/tests/test_models.py b/tests/test_models.py
index 1063c832..0471c6c9 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, List, Union, Optional, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
from datetime import datetime, timezone
from typing_extensions import Literal, Annotated, TypeAliasType
@@ -8,8 +8,8 @@
from pydantic import Field
from brainbase._utils import PropertyInfo
-from brainbase._compat import PYDANTIC_V2, parse_obj, model_dump, model_json
-from brainbase._models import BaseModel, construct_type
+from brainbase._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
+from brainbase._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
class BasicModel(BaseModel):
@@ -294,12 +294,12 @@ class Model(BaseModel):
assert cast(bool, m.foo) is True
m = Model.construct(foo={"name": 3})
- if PYDANTIC_V2:
- assert isinstance(m.foo, Submodel1)
- assert m.foo.name == 3 # type: ignore
- else:
+ if PYDANTIC_V1:
assert isinstance(m.foo, Submodel2)
assert m.foo.name == "3"
+ else:
+ assert isinstance(m.foo, Submodel1)
+ assert m.foo.name == 3 # type: ignore
def test_list_of_unions() -> None:
@@ -426,10 +426,10 @@ class Model(BaseModel):
expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc)
- if PYDANTIC_V2:
- expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}'
- else:
+ if PYDANTIC_V1:
expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}'
+ else:
+ expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}'
model = Model.construct(created_at="2019-12-27T18:11:19.117Z")
assert model.created_at == expected
@@ -492,12 +492,15 @@ class Model(BaseModel):
resource_id: Optional[str] = None
m = Model.construct()
+ assert m.resource_id is None
assert "resource_id" not in m.model_fields_set
m = Model.construct(resource_id=None)
+ assert m.resource_id is None
assert "resource_id" in m.model_fields_set
m = Model.construct(resource_id="foo")
+ assert m.resource_id == "foo"
assert "resource_id" in m.model_fields_set
@@ -528,7 +531,7 @@ class Model2(BaseModel):
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
assert m4.to_dict(mode="json") == {"created_at": time_str}
- if not PYDANTIC_V2:
+ if PYDANTIC_V1:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_dict(warnings=False)
@@ -553,7 +556,7 @@ class Model(BaseModel):
assert m3.model_dump() == {"foo": None}
assert m3.model_dump(exclude_none=True) == {}
- if not PYDANTIC_V2:
+ if PYDANTIC_V1:
with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
m.model_dump(round_trip=True)
@@ -577,10 +580,10 @@ class Model(BaseModel):
assert json.loads(m.to_json()) == {"FOO": "hello"}
assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"}
- if PYDANTIC_V2:
- assert m.to_json(indent=None) == '{"FOO":"hello"}'
- else:
+ if PYDANTIC_V1:
assert m.to_json(indent=None) == '{"FOO": "hello"}'
+ else:
+ assert m.to_json(indent=None) == '{"FOO":"hello"}'
m2 = Model()
assert json.loads(m2.to_json()) == {}
@@ -592,7 +595,7 @@ class Model(BaseModel):
assert json.loads(m3.to_json()) == {"FOO": None}
assert json.loads(m3.to_json(exclude_none=True)) == {}
- if not PYDANTIC_V2:
+ if PYDANTIC_V1:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_json(warnings=False)
@@ -619,7 +622,7 @@ class Model(BaseModel):
assert json.loads(m3.model_dump_json()) == {"foo": None}
assert json.loads(m3.model_dump_json(exclude_none=True)) == {}
- if not PYDANTIC_V2:
+ if PYDANTIC_V1:
with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
m.model_dump_json(round_trip=True)
@@ -676,12 +679,12 @@ class B(BaseModel):
)
assert isinstance(m, A)
assert m.type == "a"
- if PYDANTIC_V2:
- assert m.data == 100 # type: ignore[comparison-overlap]
- else:
+ if PYDANTIC_V1:
# pydantic v1 automatically converts inputs to strings
# if the expected type is a str
assert m.data == "100"
+ else:
+ assert m.data == 100 # type: ignore[comparison-overlap]
def test_discriminated_unions_unknown_variant() -> None:
@@ -765,12 +768,12 @@ class B(BaseModel):
)
assert isinstance(m, A)
assert m.foo_type == "a"
- if PYDANTIC_V2:
- assert m.data == 100 # type: ignore[comparison-overlap]
- else:
+ if PYDANTIC_V1:
# pydantic v1 automatically converts inputs to strings
# if the expected type is a str
assert m.data == "100"
+ else:
+ assert m.data == 100 # type: ignore[comparison-overlap]
def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None:
@@ -806,7 +809,7 @@ class B(BaseModel):
UnionType = cast(Any, Union[A, B])
- assert not hasattr(UnionType, "__discriminator__")
+ assert not DISCRIMINATOR_CACHE.get(UnionType)
m = construct_type(
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -815,7 +818,7 @@ class B(BaseModel):
assert m.type == "b"
assert m.data == "foo" # type: ignore[comparison-overlap]
- discriminator = UnionType.__discriminator__
+ discriminator = DISCRIMINATOR_CACHE.get(UnionType)
assert discriminator is not None
m = construct_type(
@@ -827,12 +830,12 @@ class B(BaseModel):
# if the discriminator details object stays the same between invocations then
# we hit the cache
- assert UnionType.__discriminator__ is discriminator
+ assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
-@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
+@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
- Alias = TypeAliasType("Alias", str)
+ Alias = TypeAliasType("Alias", str) # pyright: ignore
class Model(BaseModel):
alias: Alias
@@ -846,7 +849,7 @@ class Model(BaseModel):
assert m.union == "bar"
-@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
+@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_field_named_cls() -> None:
class Model(BaseModel):
cls: str
@@ -854,3 +857,107 @@ class Model(BaseModel):
m = construct_type(value={"cls": "foo"}, type_=Model)
assert isinstance(m, Model)
assert isinstance(m.cls, str)
+
+
+def test_discriminated_union_case() -> None:
+ class A(BaseModel):
+ type: Literal["a"]
+
+ data: bool
+
+ class B(BaseModel):
+ type: Literal["b"]
+
+ data: List[Union[A, object]]
+
+ class ModelA(BaseModel):
+ type: Literal["modelA"]
+
+ data: int
+
+ class ModelB(BaseModel):
+ type: Literal["modelB"]
+
+ required: str
+
+ data: Union[A, B]
+
+ # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
+ m = construct_type(
+ value={"type": "modelB", "data": {"type": "a", "data": True}},
+ type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]),
+ )
+
+ assert isinstance(m, ModelB)
+
+
+def test_nested_discriminated_union() -> None:
+ class InnerType1(BaseModel):
+ type: Literal["type_1"]
+
+ class InnerModel(BaseModel):
+ inner_value: str
+
+ class InnerType2(BaseModel):
+ type: Literal["type_2"]
+ some_inner_model: InnerModel
+
+ class Type1(BaseModel):
+ base_type: Literal["base_type_1"]
+ value: Annotated[
+ Union[
+ InnerType1,
+ InnerType2,
+ ],
+ PropertyInfo(discriminator="type"),
+ ]
+
+ class Type2(BaseModel):
+ base_type: Literal["base_type_2"]
+
+ T = Annotated[
+ Union[
+ Type1,
+ Type2,
+ ],
+ PropertyInfo(discriminator="base_type"),
+ ]
+
+ model = construct_type(
+ type_=T,
+ value={
+ "base_type": "base_type_1",
+ "value": {
+ "type": "type_2",
+ },
+ },
+ )
+ assert isinstance(model, Type1)
+ assert isinstance(model.value, InnerType2)
+
+
+@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now")
+def test_extra_properties() -> None:
+ class Item(BaseModel):
+ prop: int
+
+ class Model(BaseModel):
+ __pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride]
+
+ other: str
+
+ if TYPE_CHECKING:
+
+ def __getattr__(self, attr: str) -> Item: ...
+
+ model = construct_type(
+ type_=Model,
+ value={
+ "a": {"prop": 1},
+ "other": "foo",
+ },
+ )
+ assert isinstance(model, Model)
+ assert model.a.prop == 1
+ assert isinstance(model.a, Item)
+ assert model.other == "foo"
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 2293b288..9023ce5d 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -2,20 +2,20 @@
import io
import pathlib
-from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
+from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict
import pytest
-from brainbase._types import Base64FileInput
+from brainbase._types import Base64FileInput, omit, not_given
from brainbase._utils import (
PropertyInfo,
transform as _transform,
parse_datetime,
async_transform as _async_transform,
)
-from brainbase._compat import PYDANTIC_V2
+from brainbase._compat import PYDANTIC_V1
from brainbase._models import BaseModel
_T = TypeVar("_T")
@@ -189,7 +189,7 @@ class DateModel(BaseModel):
@pytest.mark.asyncio
async def test_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
- tz = "Z" if PYDANTIC_V2 else "+00:00"
+ tz = "+00:00" if PYDANTIC_V1 else "Z"
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap]
@@ -297,11 +297,11 @@ async def test_pydantic_unknown_field(use_async: bool) -> None:
@pytest.mark.asyncio
async def test_pydantic_mismatched_types(use_async: bool) -> None:
model = MyModel.construct(foo=True)
- if PYDANTIC_V2:
+ if PYDANTIC_V1:
+ params = await transform(model, Any, use_async)
+ else:
with pytest.warns(UserWarning):
params = await transform(model, Any, use_async)
- else:
- params = await transform(model, Any, use_async)
assert cast(Any, params) == {"foo": True}
@@ -309,11 +309,11 @@ async def test_pydantic_mismatched_types(use_async: bool) -> None:
@pytest.mark.asyncio
async def test_pydantic_mismatched_object_type(use_async: bool) -> None:
model = MyModel.construct(foo=MyModel.construct(hello="world"))
- if PYDANTIC_V2:
+ if PYDANTIC_V1:
+ params = await transform(model, Any, use_async)
+ else:
with pytest.warns(UserWarning):
params = await transform(model, Any, use_async)
- else:
- params = await transform(model, Any, use_async)
assert cast(Any, params) == {"foo": {"hello": "world"}}
@@ -388,6 +388,15 @@ def my_iter() -> Iterable[Baz8]:
}
+@parametrize
+@pytest.mark.asyncio
+async def test_dictionary_items(use_async: bool) -> None:
+ class DictItems(TypedDict):
+ foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+ assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}}
+
+
class TypedDictIterableUnionStr(TypedDict):
foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
@@ -423,3 +432,29 @@ async def test_base64_file_input(use_async: bool) -> None:
assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
"foo": "SGVsbG8sIHdvcmxkIQ=="
} # type: ignore[comparison-overlap]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_transform_skipping(use_async: bool) -> None:
+ # lists of ints are left as-is
+ data = [1, 2, 3]
+ assert await transform(data, List[int], use_async) is data
+
+ # iterables of ints are converted to a list
+ data = iter([1, 2, 3])
+ assert await transform(data, Iterable[int], use_async) == [1, 2, 3]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_strips_notgiven(use_async: bool) -> None:
+ assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"}
+ assert await transform({"foo_bar": not_given}, Foo1, use_async) == {}
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_strips_omit(use_async: bool) -> None:
+ assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"}
+ assert await transform({"foo_bar": omit}, Foo1, use_async) == {}
diff --git a/tests/test_utils/test_datetime_parse.py b/tests/test_utils/test_datetime_parse.py
new file mode 100644
index 00000000..be6c09ec
--- /dev/null
+++ b/tests/test_utils/test_datetime_parse.py
@@ -0,0 +1,110 @@
+"""
+Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py
+with modifications so it works without pydantic v1 imports.
+"""
+
+from typing import Type, Union
+from datetime import date, datetime, timezone, timedelta
+
+import pytest
+
+from brainbase._utils import parse_date, parse_datetime
+
+
+def create_tz(minutes: int) -> timezone:
+ return timezone(timedelta(minutes=minutes))
+
+
+@pytest.mark.parametrize(
+ "value,result",
+ [
+ # Valid inputs
+ ("1494012444.883309", date(2017, 5, 5)),
+ (b"1494012444.883309", date(2017, 5, 5)),
+ (1_494_012_444.883_309, date(2017, 5, 5)),
+ ("1494012444", date(2017, 5, 5)),
+ (1_494_012_444, date(2017, 5, 5)),
+ (0, date(1970, 1, 1)),
+ ("2012-04-23", date(2012, 4, 23)),
+ (b"2012-04-23", date(2012, 4, 23)),
+ ("2012-4-9", date(2012, 4, 9)),
+ (date(2012, 4, 9), date(2012, 4, 9)),
+ (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)),
+ # Invalid inputs
+ ("x20120423", ValueError),
+ ("2012-04-56", ValueError),
+ (19_999_999_999, date(2603, 10, 11)), # just before watershed
+ (20_000_000_001, date(1970, 8, 20)), # just after watershed
+ (1_549_316_052, date(2019, 2, 4)), # nowish in s
+ (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms
+ (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs
+ (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns
+ ("infinity", date(9999, 12, 31)),
+ ("inf", date(9999, 12, 31)),
+ (float("inf"), date(9999, 12, 31)),
+ ("infinity ", date(9999, 12, 31)),
+ (int("1" + "0" * 100), date(9999, 12, 31)),
+ (1e1000, date(9999, 12, 31)),
+ ("-infinity", date(1, 1, 1)),
+ ("-inf", date(1, 1, 1)),
+ ("nan", ValueError),
+ ],
+)
+def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None:
+ if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance]
+ with pytest.raises(result):
+ parse_date(value)
+ else:
+ assert parse_date(value) == result
+
+
+@pytest.mark.parametrize(
+ "value,result",
+ [
+ # Valid inputs
+ # values in seconds
+ ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)),
+ (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)),
+ ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)),
+ (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)),
+ (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)),
+ # values in ms
+ ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)),
+ ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)),
+ (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)),
+ ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)),
+ ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)),
+ ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)),
+ ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))),
+ ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))),
+ ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))),
+ ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))),
+ (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))),
+ (datetime(2017, 5, 5), datetime(2017, 5, 5)),
+ (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
+ # Invalid inputs
+ ("x20120423091500", ValueError),
+ ("2012-04-56T09:15:90", ValueError),
+ ("2012-04-23T11:05:00-25:00", ValueError),
+ (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed
+ (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed
+ (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s
+ (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms
+ (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in μs
+ (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns
+ ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)),
+ ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)),
+ ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)),
+ (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)),
+ (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)),
+ ("-infinity", datetime(1, 1, 1, 0, 0)),
+ ("-inf", datetime(1, 1, 1, 0, 0)),
+ ("nan", ValueError),
+ ],
+)
+def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None:
+ if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance]
+ with pytest.raises(result):
+ parse_datetime(value)
+ else:
+ assert parse_datetime(value) == result
diff --git a/tests/test_utils/test_proxy.py b/tests/test_utils/test_proxy.py
index d79d1bb5..92ae06db 100644
--- a/tests/test_utils/test_proxy.py
+++ b/tests/test_utils/test_proxy.py
@@ -21,3 +21,14 @@ def test_recursive_proxy() -> None:
assert dir(proxy) == []
assert type(proxy).__name__ == "RecursiveLazyProxy"
assert type(operator.attrgetter("name.foo.bar.baz")(proxy)).__name__ == "RecursiveLazyProxy"
+
+
+def test_isinstance_does_not_error() -> None:
+ class AlwaysErrorProxy(LazyProxy[Any]):
+ @override
+ def __load__(self) -> Any:
+ raise RuntimeError("Mocking missing dependency")
+
+ proxy = AlwaysErrorProxy()
+ assert not isinstance(proxy, dict)
+ assert isinstance(proxy, LazyProxy)
diff --git a/tests/utils.py b/tests/utils.py
index fa696b4a..5dbc46dc 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,7 +4,7 @@
import inspect
import traceback
import contextlib
-from typing import Any, TypeVar, Iterator, cast
+from typing import Any, TypeVar, Iterator, Sequence, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type
@@ -15,10 +15,11 @@
is_list_type,
is_union_type,
extract_type_arg,
+ is_sequence_type,
is_annotated_type,
is_type_alias_type,
)
-from brainbase._compat import PYDANTIC_V2, field_outer_type, get_model_fields
+from brainbase._compat import PYDANTIC_V1, field_outer_type, get_model_fields
from brainbase._models import BaseModel
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
@@ -27,12 +28,12 @@
def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
for name, field in get_model_fields(model).items():
field_value = getattr(value, name)
- if PYDANTIC_V2:
- allow_none = False
- else:
+ if PYDANTIC_V1:
# in v1 nullability was structured differently
# https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
allow_none = getattr(field, "allow_none", False)
+ else:
+ allow_none = False
assert_matches_type(
field_outer_type(field),
@@ -71,6 +72,13 @@ def assert_matches_type(
if is_list_type(type_):
return _assert_list_type(type_, value)
+ if is_sequence_type(type_):
+ assert isinstance(value, Sequence)
+ inner_type = get_args(type_)[0]
+ for entry in value: # type: ignore
+ assert_type(inner_type, entry) # type: ignore
+ return
+
if origin == str:
assert isinstance(value, str)
elif origin == int: