From eecd8a9037b1889e1359d76249f279112955a95d Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Mon, 1 Jun 2026 11:40:17 +0200 Subject: [PATCH 1/3] Make the async code match the sync code, keep the asyncIterator so in the future we could reintroduce stream as a special version of many. Fixes: https://github.com/sqlc-dev/sqlc-gen-python/issues/101 --- .../testdata/emit_pydantic_models/db/models.py | 2 +- .../testdata/emit_pydantic_models/db/query.py | 6 +++--- internal/endtoend/testdata/emit_str_enum/db/models.py | 2 +- internal/endtoend/testdata/emit_str_enum/db/query.py | 6 +++--- .../endtoend/testdata/exec_result/python/models.py | 2 +- .../endtoend/testdata/exec_result/python/query.py | 2 +- internal/endtoend/testdata/exec_rows/python/models.py | 2 +- internal/endtoend/testdata/exec_rows/python/query.py | 2 +- .../inflection_exclude_table_names/python/models.py | 2 +- .../inflection_exclude_table_names/python/query.py | 2 +- .../query_parameter_limit_two/python/models.py | 2 +- .../query_parameter_limit_two/python/query.py | 2 +- .../query_parameter_limit_undefined/python/models.py | 2 +- .../query_parameter_limit_undefined/python/query.py | 2 +- .../query_parameter_limit_zero/python/models.py | 2 +- .../query_parameter_limit_zero/python/query.py | 2 +- internal/gen.go | 11 +++++++---- 17 files changed, 27 insertions(+), 24 deletions(-) diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/models.py b/internal/endtoend/testdata/emit_pydantic_models/db/models.py index 7676e5c..69c6835 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/models.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import pydantic from typing import Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/query.py b/internal/endtoend/testdata/emit_pydantic_models/db/query.py index 6f5b76f..86fd5c9 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/query.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql from typing import AsyncIterator, Iterator, Optional @@ -103,8 +103,8 @@ async def get_author(self, *, id: int) -> Optional[models.Author]: ) async def list_authors(self) -> AsyncIterator[models.Author]: - result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS)) - async for row in result: + rows = (await self._conn.execute(sqlalchemy.text(LIST_AUTHORS))).all() + for row in rows: yield models.Author( id=row[0], name=row[1], diff --git a/internal/endtoend/testdata/emit_str_enum/db/models.py b/internal/endtoend/testdata/emit_str_enum/db/models.py index 5fdf754..a18335d 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/models.py +++ b/internal/endtoend/testdata/emit_str_enum/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses import enum from typing import Optional diff --git a/internal/endtoend/testdata/emit_str_enum/db/query.py b/internal/endtoend/testdata/emit_str_enum/db/query.py index 8082889..604bd66 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/query.py +++ b/internal/endtoend/testdata/emit_str_enum/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql from typing import AsyncIterator, Iterator, Optional @@ -102,8 +102,8 @@ async def get_book(self, *, id: int) -> Optional[models.Book]: ) async def list_books(self) -> AsyncIterator[models.Book]: - result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS)) - async for row in result: + rows = (await self._conn.execute(sqlalchemy.text(LIST_BOOKS))).all() + for row in rows: yield models.Book( id=row[0], title=row[1], diff --git a/internal/endtoend/testdata/exec_result/python/models.py b/internal/endtoend/testdata/exec_result/python/models.py index 034fb2d..f4a55a5 100644 --- a/internal/endtoend/testdata/exec_result/python/models.py +++ b/internal/endtoend/testdata/exec_result/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/exec_result/python/query.py b/internal/endtoend/testdata/exec_result/python/query.py index b68ce39..b86bf48 100644 --- a/internal/endtoend/testdata/exec_result/python/query.py +++ b/internal/endtoend/testdata/exec_result/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_rows/python/models.py b/internal/endtoend/testdata/exec_rows/python/models.py index 034fb2d..f4a55a5 100644 --- a/internal/endtoend/testdata/exec_rows/python/models.py +++ b/internal/endtoend/testdata/exec_rows/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/exec_rows/python/query.py b/internal/endtoend/testdata/exec_rows/python/query.py index 7a9b2a6..cd0c464 100644 --- a/internal/endtoend/testdata/exec_rows/python/query.py +++ b/internal/endtoend/testdata/exec_rows/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py index 8ba8803..f238065 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py index 1e1e161..09edceb 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql from typing import Optional diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py index 059675d..c02ab26 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py index e8b723e..9e61b05 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py index 30e80db..f5fbb60 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py index 5a1fbbc..ef54438 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py index 059675d..c02ab26 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py index 47bd6a9..0b4257d 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.31.1 # source: query.sql import dataclasses diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..11616f3 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -1020,13 +1020,16 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ) f.Returns = subscriptNode("Optional", q.Ret.Annotation()) case ":many": - stream := connMethodNode("stream", q.ConstantName, q.ArgDictNode()) f.Body = append(f.Body, - assignNode("result", poet.Await(stream)), + assignNode("rows", poet.Node( + &pyast.Call{ + Func: poet.Attribute(poet.Await(exec), "all"), + }, + )), poet.Node( - &pyast.AsyncFor{ + &pyast.For{ Target: poet.Name("row"), - Iter: poet.Name("result"), + Iter: poet.Name("rows"), Body: []*pyast.Node{ poet.Expr( poet.Yield( From f14a9b718d6eb6e00029498f6c948069e0eadd07 Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Mon, 1 Jun 2026 11:40:55 +0200 Subject: [PATCH 2/3] Add .idea to the gitignore. --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 60b4f3d..baddcc4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ bin .direnv .devenv* devenv.local.nix + +.idea From 2c997e4fefb166cfdabb10b86253e9314d6173ab Mon Sep 17 00:00:00 2001 From: Boris Smidt Date: Mon, 15 Jun 2026 09:40:00 +0200 Subject: [PATCH 3/3] Add a type override function for domain types. --- README.md | 39 ++++++++++ internal/config.go | 6 ++ .../testdata/domain_overrides/db/models.py | 15 ++++ .../testdata/domain_overrides/db/query.py | 72 +++++++++++++++++++ .../testdata/domain_overrides/query.sql | 8 +++ .../testdata/domain_overrides/schema.sql | 19 +++++ .../testdata/domain_overrides/sqlc.yaml | 20 ++++++ .../testdata/emit_pydantic_models/sqlc.yaml | 2 +- .../endtoend/testdata/emit_str_enum/sqlc.yaml | 2 +- .../endtoend/testdata/exec_result/sqlc.yaml | 2 +- .../endtoend/testdata/exec_rows/sqlc.yaml | 2 +- .../inflection_exclude_table_names/sqlc.yaml | 2 +- .../query_parameter_limit_two/sqlc.yaml | 2 +- .../query_parameter_limit_undefined/sqlc.yaml | 2 +- .../query_parameter_limit_zero/sqlc.yaml | 2 +- .../query_parameter_no_limit/sqlc.yaml | 2 +- internal/gen.go | 43 +++++++---- internal/imports.go | 35 +++++++++ internal/postgresql_type.go | 34 ++++++++- 19 files changed, 287 insertions(+), 22 deletions(-) create mode 100644 internal/endtoend/testdata/domain_overrides/db/models.py create mode 100644 internal/endtoend/testdata/domain_overrides/db/query.py create mode 100644 internal/endtoend/testdata/domain_overrides/query.sql create mode 100644 internal/endtoend/testdata/domain_overrides/schema.sql create mode 100644 internal/endtoend/testdata/domain_overrides/sqlc.yaml diff --git a/README.md b/README.md index c9f2531..86d5017 100644 --- a/README.md +++ b/README.md @@ -76,3 +76,42 @@ class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" ``` + +### Map domains (and other unknown types) to Python types + +Option: `domain_overrides` + +sqlc does not pass `CREATE DOMAIN` definitions (their base type or `CHECK` +constraints) to code generation plugins, so columns using a domain are emitted +as `Any` and a `unknown PostgreSQL type` warning is logged. The +`domain_overrides` option lets you map a PostgreSQL type name to a +fully-qualified Python type. The required `import` is added automatically, +including for nested modules. + +```yaml +options: + package: authors + domain_overrides: + job_status: my.module.JobStatus + positive_int: decimal.Decimal +``` + +Given a domain `job_status` used by a `status` column, this generates: + +```py +import decimal + +import my.module + + +@dataclasses.dataclass() +class Job: + id: int + status: my.module.JobStatus + priority: Optional[decimal.Decimal] +``` + +The key is matched against the column's data type, its bare type name, and its +schema-qualified name (e.g. `public.job_status`), so you can key the override +however is most convenient. This also works for any other type sqlc reports as +unknown, not just domains. diff --git a/internal/config.go b/internal/config.go index 1a8a565..7cb92ba 100644 --- a/internal/config.go +++ b/internal/config.go @@ -10,4 +10,10 @@ type Config struct { EmitStrEnum bool `json:"emit_str_enum"` QueryParameterLimit *int32 `json:"query_parameter_limit"` InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + + // DomainOverrides maps a PostgreSQL type name (typically a DOMAIN, whose + // definition sqlc does not pass to plugins) to a fully-qualified Python + // type. For example {"job_status": "my.module.JobStatus"} emits a + // "import my.module" and annotates the column as "my.module.JobStatus". + DomainOverrides map[string]string `json:"domain_overrides"` } diff --git a/internal/endtoend/testdata/domain_overrides/db/models.py b/internal/endtoend/testdata/domain_overrides/db/models.py new file mode 100644 index 0000000..c348bc9 --- /dev/null +++ b/internal/endtoend/testdata/domain_overrides/db/models.py @@ -0,0 +1,15 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.31.1 +import dataclasses +import decimal +from typing import Optional + +import my.module + + +@dataclasses.dataclass() +class Job: + id: int + status: my.module.JobStatus + priority: Optional[decimal.Decimal] diff --git a/internal/endtoend/testdata/domain_overrides/db/query.py b/internal/endtoend/testdata/domain_overrides/db/query.py new file mode 100644 index 0000000..7773908 --- /dev/null +++ b/internal/endtoend/testdata/domain_overrides/db/query.py @@ -0,0 +1,72 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.31.1 +# source: query.sql +from typing import AsyncIterator, Iterator, Optional + +import my.module +import sqlalchemy +import sqlalchemy.ext.asyncio + +from db import models + + +GET_JOB = """-- name: get_job \\:one +SELECT id, status, priority FROM jobs +WHERE id = :p1 LIMIT 1 +""" + + +LIST_JOBS_BY_STATUS = """-- name: list_jobs_by_status \\:many +SELECT id, status, priority FROM jobs +WHERE status = :p1 +ORDER BY priority +""" + + +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def get_job(self, *, id: int) -> Optional[models.Job]: + row = self._conn.execute(sqlalchemy.text(GET_JOB), {"p1": id}).first() + if row is None: + return None + return models.Job( + id=row[0], + status=row[1], + priority=row[2], + ) + + def list_jobs_by_status(self, *, status: my.module.JobStatus) -> Iterator[models.Job]: + result = self._conn.execute(sqlalchemy.text(LIST_JOBS_BY_STATUS), {"p1": status}) + for row in result: + yield models.Job( + id=row[0], + status=row[1], + priority=row[2], + ) + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def get_job(self, *, id: int) -> Optional[models.Job]: + row = (await self._conn.execute(sqlalchemy.text(GET_JOB), {"p1": id})).first() + if row is None: + return None + return models.Job( + id=row[0], + status=row[1], + priority=row[2], + ) + + async def list_jobs_by_status(self, *, status: my.module.JobStatus) -> AsyncIterator[models.Job]: + rows = (await self._conn.execute(sqlalchemy.text(LIST_JOBS_BY_STATUS), {"p1": status})).all() + for row in rows: + yield models.Job( + id=row[0], + status=row[1], + priority=row[2], + ) diff --git a/internal/endtoend/testdata/domain_overrides/query.sql b/internal/endtoend/testdata/domain_overrides/query.sql new file mode 100644 index 0000000..936746f --- /dev/null +++ b/internal/endtoend/testdata/domain_overrides/query.sql @@ -0,0 +1,8 @@ +-- name: GetJob :one +SELECT * FROM jobs +WHERE id = $1 LIMIT 1; + +-- name: ListJobsByStatus :many +SELECT * FROM jobs +WHERE status = $1 +ORDER BY priority; diff --git a/internal/endtoend/testdata/domain_overrides/schema.sql b/internal/endtoend/testdata/domain_overrides/schema.sql new file mode 100644 index 0000000..58fc282 --- /dev/null +++ b/internal/endtoend/testdata/domain_overrides/schema.sql @@ -0,0 +1,19 @@ +CREATE DOMAIN job_status AS text +CHECK ( + VALUE IN ( + 'QUEUED', + 'PENDING', + 'RUNNING', + 'COMPLETED', + 'FAILED' + )) NOT NULL; + +CREATE DOMAIN positive_int AS integer +CHECK (VALUE > 0); + + +CREATE TABLE jobs ( + id BIGSERIAL PRIMARY KEY, + status job_status NOT NULL, + priority positive_int +); diff --git a/internal/endtoend/testdata/domain_overrides/sqlc.yaml b/internal/endtoend/testdata/domain_overrides/sqlc.yaml new file mode 100644 index 0000000..05ef3e4 --- /dev/null +++ b/internal/endtoend/testdata/domain_overrides/sqlc.yaml @@ -0,0 +1,20 @@ +version: "2" +plugins: + - name: py + wasm: + url: file://../../../../bin/sqlc-gen-python.wasm + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" +sql: + - schema: schema.sql + queries: query.sql + engine: postgresql + codegen: + - plugin: py + out: db + options: + package: db + emit_sync_querier: true + emit_async_querier: true + domain_overrides: + job_status: my.module.JobStatus + positive_int: decimal.Decimal diff --git a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml index beae200..1e3ad44 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml +++ b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml index 04e3feb..00621aa 100644 --- a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml +++ b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_result/sqlc.yaml b/internal/endtoend/testdata/exec_result/sqlc.yaml index ddffc83..d8f39ee 100644 --- a/internal/endtoend/testdata/exec_result/sqlc.yaml +++ b/internal/endtoend/testdata/exec_result/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_rows/sqlc.yaml b/internal/endtoend/testdata/exec_rows/sqlc.yaml index ddffc83..d8f39ee 100644 --- a/internal/endtoend/testdata/exec_rows/sqlc.yaml +++ b/internal/endtoend/testdata/exec_rows/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml index efbb150..ba13e2f 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml +++ b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml index 336bca7..880c27a 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml index c20cd57..be465be 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml index 6e2cdeb..76df648 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml index c432e4f..27d6eb9 100644 --- a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "00c7c16380c4593d7a86b82e2f650c5655c179cdb0d63b1513d6987ec9be0f46" sql: - schema: schema.sql queries: query.sql diff --git a/internal/gen.go b/internal/gen.go index 11616f3..63596e1 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -180,8 +180,8 @@ func (q Query) ArgDictNode() *pyast.Node { } } -func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { - typ := pyInnerType(req, col) +func makePyType(conf Config, req *plugin.GenerateRequest, col *plugin.Column) pyType { + typ := pyInnerType(conf, req, col) return pyType{ InnerType: typ, IsArray: col.IsArray, @@ -189,10 +189,10 @@ func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { } } -func pyInnerType(req *plugin.GenerateRequest, col *plugin.Column) string { +func pyInnerType(conf Config, req *plugin.GenerateRequest, col *plugin.Column) string { switch req.Settings.Engine { case "postgresql": - return postgresType(req, col) + return postgresType(conf, req, col) default: log.Println("unsupported engine type") return "Any" @@ -285,7 +285,7 @@ func buildModels(conf Config, req *plugin.GenerateRequest) []Struct { Comment: table.Comment, } for _, column := range table.Columns { - typ := makePyType(req, column) // TODO: This used to call compiler.ConvertColumn? + typ := makePyType(conf, req, column) // TODO: This used to call compiler.ConvertColumn? typ.InnerType = strings.TrimPrefix(typ.InnerType, "models.") s.Fields = append(s.Fields, Field{ Name: column.Name, @@ -321,7 +321,7 @@ type pyColumn struct { *plugin.Column } -func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct { +func columnsToStruct(conf Config, req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct { gs := Struct{ Name: name, } @@ -344,7 +344,7 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum } gs.Fields = append(gs.Fields, Field{ Name: fieldName, - Type: makePyType(req, c.Column), + Type: makePyType(conf, req, c.Column), }) seen[colName]++ } @@ -406,14 +406,14 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ gq.Args = []QueryValue{{ Emit: true, Name: "arg", - Struct: columnsToStruct(req, query.Name+"Params", cols), + Struct: columnsToStruct(conf, req, query.Name+"Params", cols), }} } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { args = append(args, QueryValue{ Name: paramName(p), - Typ: makePyType(req, p.Column), + Typ: makePyType(conf, req, p.Column), }) } gq.Args = args @@ -423,7 +423,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ c := query.Columns[0] gq.Ret = QueryValue{ Name: columnName(c, 0), - Typ: makePyType(req, c), + Typ: makePyType(conf, req, c), } } else if len(query.Columns) > 1 { var gs *Struct @@ -438,7 +438,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ for i, f := range s.Fields { c := query.Columns[i] // HACK: models do not have "models." on their types, so trim that so we can find matches - trimmedPyType := makePyType(req, c) + trimmedPyType := makePyType(conf, req, c) trimmedPyType.InnerType = strings.TrimPrefix(trimmedPyType.InnerType, "models.") sameName := f.Name == columnName(c, i) sameType := f.Type == trimmedPyType @@ -461,7 +461,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ Column: c, }) } - gs = columnsToStruct(req, query.Name+"Row", columns) + gs = columnsToStruct(conf, req, query.Name+"Row", columns) emit = true } gq.Ret = QueryValue{ @@ -1041,6 +1041,25 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ), ) f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) + case ":stream": + stream := connMethodNode("stream", q.ConstantName, q.ArgDictNode()) + f.Body = append(f.Body, + assignNode("result", poet.Await(stream)), + poet.Node( + &pyast.AsyncFor{ + Target: poet.Name("row"), + Iter: poet.Name("result"), + Body: []*pyast.Node{ + poet.Expr( + poet.Yield( + q.Ret.RowNode("row"), + ), + ), + }, + }, + ), + ) + f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) case ":exec": f.Body = append(f.Body, poet.Await(exec)) f.Returns = poet.Constant(nil) diff --git a/internal/imports.go b/internal/imports.go index b88c58c..ac5c9ed 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -96,6 +96,7 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS } pkg := make(map[string]importSpec) + addOverrideImports(pkg, std, i.C, modelUses) return std, pkg } @@ -135,6 +136,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map if i.C.EmitAsyncQuerier { pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"} } + addOverrideImports(pkg, std, i.C, queryUses) queryValueModelImports := func(qv QueryValue) { if qv.IsStruct() && qv.EmitStruct() { @@ -252,6 +254,39 @@ func buildImportBlock(pkgs map[string]importSpec) string { return strings.Join(importStrings, "\n") } +// addOverrideImports adds an "import " for every configured domain +// override whose Python type is actually referenced. For a value like +// "my.module.JobStatus" it imports "my.module" (the segment before the final +// dot). Bare names without a module are left untouched. +func addOverrideImports(pkg, std map[string]importSpec, conf Config, uses func(name string) bool) { + for _, pyType := range conf.DomainOverrides { + if !uses(pyType) { + continue + } + module := moduleOf(pyType) + if module == "" || moduleImported(std, module) || moduleImported(pkg, module) { + continue + } + pkg[module] = importSpec{Module: module} + } +} + +func moduleImported(specs map[string]importSpec, module string) bool { + for _, s := range specs { + if s.Module == module { + return true + } + } + return false +} + +func moduleOf(pyType string) string { + if idx := strings.LastIndex(pyType, "."); idx >= 0 { + return pyType[:idx] + } + return "" +} + func stdImports(uses func(name string) bool) map[string]importSpec { std := make(map[string]importSpec) if uses("decimal.Decimal") { diff --git a/internal/postgresql_type.go b/internal/postgresql_type.go index 3d0891b..8759472 100644 --- a/internal/postgresql_type.go +++ b/internal/postgresql_type.go @@ -7,9 +7,16 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string { +func postgresType(conf Config, req *plugin.GenerateRequest, col *plugin.Column) string { columnType := sdk.DataType(col.Type) + // User-configured overrides take precedence. This is the only way to map + // PostgreSQL DOMAINs (and other types sqlc does not describe to plugins) to + // a concrete Python type, since their definitions never reach the plugin. + if py, ok := domainOverride(conf, col, columnType); ok { + return py + } + switch columnType { case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": return "int" @@ -60,3 +67,28 @@ func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string { return "Any" } } + +// domainOverride looks up a configured Python type for the given column type. +// It accepts the plain data-type string (e.g. "job_status"), the bare type +// name, and the schema-qualified name (e.g. "public.job_status"), so users can +// key the override however is most convenient. +func domainOverride(conf Config, col *plugin.Column, columnType string) (string, bool) { + if len(conf.DomainOverrides) == 0 { + return "", false + } + candidates := []string{columnType} + if col.Type != nil { + if col.Type.Name != "" { + candidates = append(candidates, col.Type.Name) + } + if col.Type.Schema != "" && col.Type.Name != "" { + candidates = append(candidates, col.Type.Schema+"."+col.Type.Name) + } + } + for _, c := range candidates { + if py, ok := conf.DomainOverrides[c]; ok { + return py, true + } + } + return "", false +}