Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 0 additions & 70 deletions scripts/generate-schema/gen-python.test.ts

This file was deleted.

56 changes: 25 additions & 31 deletions scripts/generate-schema/gen-python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@ import { assert } from "jsr:@std/assert";
const header = (relativePath: string) =>
`# DO NOT EDIT: This file is auto-generated by ${relativePath}\n` +
"from enum import Enum\n" +
"from typing import Any, Literal, Optional, Union\n" +
"import msgspec\n\n";

export function extractExportedNames(content: string): string[] {
const names = new Set<string>();

// Match class definitions and enum assignments
// Match class definitions and union type assignments
const classRegex = /^class\s+([a-zA-Z_][a-zA-Z0-9_]*)/gm;
const enumRegex = /^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*Union\[/gm;
const unionRegex = /^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*.+\|.+/gm;

let match;
while ((match = classRegex.exec(content)) !== null) {
names.add(match[1]);
}
while ((match = enumRegex.exec(content)) !== null) {
while ((match = unionRegex.exec(content)) !== null) {
names.add(match[1]);
}

Expand Down Expand Up @@ -51,10 +50,7 @@ export function generatePython(
return output + content;
}

function generateTypes(
doc: Doc,
name: string,
) {
function generateTypes(doc: Doc, name: string) {
const writer = new Writer();

let definitions = "";
Expand Down Expand Up @@ -97,9 +93,7 @@ function generateTypes(
return definitions + writer.output();
}

function sortByRequired<T extends { required: boolean }>(
properties: T[],
): T[] {
function sortByRequired<T extends { required: boolean }>(properties: T[]): T[] {
return [...properties].sort((a, b) => {
if (a.required === b.required) return 0;
return a.required ? -1 : 1;
Expand All @@ -116,9 +110,8 @@ function generateNode(node: Node, writer: Writer) {
.with({ type: "boolean" }, () => w("bool"))
.with({ type: "string" }, () => w("str"))
.with({ type: "literal" }, (node) => w(`Literal["${node.value}"]`))
.with(
{ type: "record" },
(node) => w(`dict[str, ${mapPythonType(node.valueType)}]`),
.with({ type: "record" }, (node) =>
w(`dict[str, ${mapPythonType(node.valueType)}]`),
)
.with({ type: "enum" }, (node) => {
wn(`class ${node.name}(str, Enum):`);
Expand All @@ -134,9 +127,10 @@ function generateNode(node: Node, writer: Writer) {
if (m.name) {
name = m.name;
} else {
const ident = m.type === "object"
? m.properties?.find((p) => p.required)?.key ?? ""
: "";
const ident =
m.type === "object"
? (m.properties?.find((p) => p.required)?.key ?? "")
: "";
name = `${node.name}${cap(ident)}`;
}
if (!generatedDependentClasses.has(name)) {
Expand All @@ -148,17 +142,17 @@ function generateNode(node: Node, writer: Writer) {
return name;
});
writer.append(depWriter.output());
wn(`${node.name} = Union[${classes.join(", ")}]`);
wn(`${node.name} = ${classes.join(" | ")}`);
})
.with({ type: "object" }, (node) => {
match(context.parent)
.with({ type: "union" }, () => {
const name = context.closestName();
const ident = node.properties.find((p) => p.required)?.key ?? "";
wn(
`class ${name}${
cap(ident)
}(msgspec.Struct, kw_only=True, omit_defaults=True):`,
`class ${name}${cap(
ident,
)}(msgspec.Struct, kw_only=True, omit_defaults=True):`,
);
})
.with(P.nullish, () => {
Expand All @@ -179,9 +173,8 @@ function generateNode(node: Node, writer: Writer) {

for (const { key, required, description, value } of sortedProperties) {
w(` ${key}: `);
if (!required) w("Union[");
generateNode(value, writer);
if (!required) w(", None] = None");
if (!required) w(" | None = None");
wn("");
if (description) {
wn(` """${description}"""`);
Expand All @@ -193,7 +186,6 @@ function generateNode(node: Node, writer: Writer) {
const depWriter = new Writer();
const { w: d, wn: dn } = depWriter.shorthand();
const classes: string[] = [];
w("Union[");
for (const [name, properties] of Object.entries(node.members)) {
for (const { value } of properties) {
if (isComplexType(value)) {
Expand All @@ -213,15 +205,17 @@ function generateNode(node: Node, writer: Writer) {

const sortedProperties = sortByRequired(properties);

for (
const { key, required, description, value } of sortedProperties
) {
for (const {
key,
required,
description,
value,
} of sortedProperties) {
d(` ${key}: `);
if (!required) d("Union[");
!isComplexType(value)
? generateNode(value, depWriter)
: d(value.name ?? value.type);
if (!required) d(", None] = None");
if (!required) d(" | None = None");
dn("");
if (description) {
dn(` """${description}"""`);
Expand All @@ -230,9 +224,9 @@ function generateNode(node: Node, writer: Writer) {
dn("");
}
}
w(classes.join(", "));
w(classes.join(" | "));
writer.prepend(depWriter.output());
wn("]");
wn("");
})
.with({ type: "intersection" }, (node) => {
assert(
Expand Down
14 changes: 7 additions & 7 deletions src/clients/python/src/justbe_webview/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import platform
import subprocess
from typing import Any, Callable, Literal, Union, cast, TypeVar
from typing import Any, Callable, Literal, cast, TypeVar
from pathlib import Path
import aiofiles
import httpx
Expand Down Expand Up @@ -54,7 +54,7 @@


def return_result(
result: Union[AckResponse, ResultResponse, ErrResponse],
result: AckResponse | ResultResponse | ErrResponse,
expected_type: type[ResultType],
) -> Any:
print(f"Return result: {result}")
Expand All @@ -63,7 +63,7 @@ def return_result(
raise ValueError(f"Expected {expected_type.__name__} result got: {result}")


def return_ack(result: Union[AckResponse, ResultResponse, ErrResponse]) -> None:
def return_ack(result: AckResponse | ResultResponse | ErrResponse) -> None:
print(f"Return ack: {result}")
if isinstance(result, AckResponse):
return
Expand Down Expand Up @@ -212,11 +212,11 @@ async def __aexit__(
async def send(self, request: WebViewRequest) -> WebViewResponse:
if self.process is None:
raise RuntimeError("Webview process not started")
future: asyncio.Future[Union[AckResponse, ResultResponse, ErrResponse]] = (
future: asyncio.Future[AckResponse | ResultResponse | ErrResponse] = (
asyncio.Future()
)

def set_result(event: Union[AckResponse, ResultResponse, ErrResponse]) -> None:
def set_result(event: AckResponse | ResultResponse | ErrResponse) -> None:
future.set_result(event)

self.internal_event.once(str(request.id), set_result) # type: ignore
Expand All @@ -229,7 +229,7 @@ def set_result(event: Union[AckResponse, ResultResponse, ErrResponse]) -> None:
result = await future
return result

async def recv(self) -> Union[WebViewNotification, None]:
async def recv(self) -> WebViewNotification | None:
if self.process is None:
raise RuntimeError("Webview process not started")

Expand Down Expand Up @@ -319,7 +319,7 @@ async def set_size(self, size: dict[Literal["width", "height"], float]):

async def get_size(
self, include_decorations: bool = False
) -> dict[Literal["width", "height", "scaleFactor"], Union[int, float]]:
) -> dict[Literal["width", "height", "scaleFactor"], int | float]:
result = await self.send(
GetSizeRequest(id=self.message_id, include_decorations=include_decorations)
)
Expand Down
Loading