-
Notifications
You must be signed in to change notification settings - Fork 18
feat: Add optional sqlalchemy_use_enum for dy.Enum #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,8 @@ def __init__( | |
| alias: str | None = None, | ||
| metadata: dict[str, Any] | None = None, | ||
| description: str | None = None, | ||
| sqlalchemy_use_enum: bool = False, | ||
| sqlalchemy_enum_name: str | None = None, | ||
| ): | ||
| """ | ||
| Args: | ||
|
|
@@ -68,6 +70,17 @@ def __init__( | |
| names, the specified alias is the only valid name. | ||
| metadata: A dictionary of metadata to attach to the column. | ||
| description: A human-readable description of the column. | ||
| sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum` | ||
| in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of | ||
| ``CHAR`` / ``VARCHAR``. Use this for PostgreSQL native enum types and | ||
| Alembic schema drift detection. Defaults to ``False`` (string columns). | ||
| sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type | ||
| when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a | ||
| Python :class:`enum.Enum` subclass, SQLAlchemy uses the enum class name | ||
| (lowercased). Otherwise the SQL column name from | ||
| :meth:`~dataframely.Schema.to_sqlalchemy_columns` is used. For Python | ||
| enums, persisted values are the enum members' ``.value`` strings (not | ||
| member names), matching :attr:`categories`. | ||
| """ | ||
| super().__init__( | ||
| nullable=nullable, | ||
|
|
@@ -78,7 +91,11 @@ def __init__( | |
| metadata=metadata, | ||
| description=description, | ||
| ) | ||
| self.sqlalchemy_use_enum = sqlalchemy_use_enum | ||
| self.sqlalchemy_enum_name = sqlalchemy_enum_name | ||
| self._enum_class: type[enum.Enum] | None = None | ||
| if isclass(categories) and issubclass(categories, enum.Enum): | ||
| self._enum_class = categories | ||
| categories = (item.value for item in categories) | ||
| self.categories = list(categories) | ||
|
|
||
|
|
@@ -91,12 +108,45 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool: | |
| return False | ||
| return self.categories == dtype.categories.to_list() | ||
|
|
||
| def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column: | ||
| if self.sqlalchemy_use_enum: | ||
| column = super().sqlalchemy_column(name, dialect) | ||
| column.type = self._sqlalchemy_enum_type(dialect, column_name=name) | ||
| return column | ||
|
Comment on lines
+112
to
+115
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would be good to add a comment here to say why we need to override this method from the parent at all. I assume it's because we want to pass |
||
| return super().sqlalchemy_column(name, dialect) | ||
|
Comment on lines
+111
to
+116
|
||
|
|
||
| def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: | ||
| if self.sqlalchemy_use_enum: | ||
| column_name = self._name or None | ||
| return self._sqlalchemy_enum_type(dialect, column_name=column_name) | ||
| category_lengths = [len(c) for c in self.categories] | ||
| if all(length == category_lengths[0] for length in category_lengths): | ||
| return sa.CHAR(category_lengths[0]) | ||
| return sa.String(max(category_lengths)) | ||
|
|
||
| def _sqlalchemy_enum_type( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is taking me a surprisingly long time to mentally parse this method. Do you think you could add some more comments and possibly structure the code to make it easier to understand in case the reader is less familiar with
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For example, is it possible to factor this into a |
||
| self, _dialect: sa.Dialect, *, column_name: str | None | ||
| ) -> sa_TypeEngine: | ||
| length = max(len(c) for c in self.categories) | ||
| kwargs: dict[str, Any] = {"length": length} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Concrete example: Do we need the |
||
| name = self.sqlalchemy_enum_name | ||
| if self._enum_class is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would definitely be useful to add a comment here to say what |
||
| if name is not None: | ||
| kwargs["name"] = name | ||
| kwargs["values_callable"] = lambda enum: [member.value for member in enum] | ||
| return sa.Enum(self._enum_class, **kwargs) | ||
|
AndreasAlbertQC marked this conversation as resolved.
|
||
| if name is None: | ||
| name = column_name | ||
| if name is None: | ||
| raise ValueError( | ||
| "sqlalchemy_enum_name is required for dy.Enum with string categories " | ||
| "and sqlalchemy_use_enum=True when not building columns via " | ||
| "Schema.to_sqlalchemy_columns(). Alternatively, pass a Python " | ||
| "enum.Enum class as categories." | ||
| ) | ||
|
Comment on lines
+142
to
+146
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can this already be checked in the constructor (early) rather than here (late)? |
||
| kwargs["name"] = name | ||
| return sa.Enum(*self.categories, **kwargs) | ||
|
|
||
| @property | ||
| def pyarrow_dtype(self) -> pa.DataType: | ||
| if len(self.categories) <= 2**8 - 1: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,27 @@ the maximal length of the string is inferred from the regular expression if poss | |
| maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys. | ||
| ``` | ||
|
|
||
| ## Native SQL enums (optional) | ||
|
|
||
| By default, {class}`~dataframely.Enum` maps to fixed-length `CHAR` or `VARCHAR` columns so stored values remain plain strings. For PostgreSQL setups that use database-level `ENUM` types (for example with Alembic autogenerate), set `sqlalchemy_use_enum=True`: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of saying |
||
|
|
||
| ```python | ||
| from enum import Enum | ||
|
|
||
| import dataframely as dy | ||
|
|
||
|
|
||
| class Status(str, Enum): | ||
| PENDING = "pending" | ||
| APPROVED = "approved" | ||
|
|
||
|
|
||
| class Staged(dy.Schema): | ||
| status = dy.Enum(Status, sqlalchemy_use_enum=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it make sense to expand this example by showing what |
||
| ``` | ||
|
|
||
| When `categories` is a Python `enum.Enum` subclass, SQLAlchemy uses the enum class name (lowercased) as the database enum type name. For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed. On dialects without native enums (such as Microsoft SQL Server), SQLAlchemy falls back to `VARCHAR` with a check constraint. | ||
|
|
||
| ## Collections of multiple tables | ||
|
|
||
| If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,3 +108,43 @@ def test_sequences_and_enums( | |
| S = create_schema("test", {"x": dy.Enum(categories1)}) | ||
| df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))}) | ||
| S.validate(df) | ||
|
|
||
|
|
||
| def test_matches_sqlalchemy_use_enum() -> None: | ||
| expr = pl.element() | ||
| assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr) | ||
| assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( | ||
| dy.Enum(["a", "b"]), expr | ||
| ) | ||
| assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( | ||
| dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr | ||
| ) | ||
|
|
||
|
|
||
| def test_matches_sqlalchemy_enum_name() -> None: | ||
| expr = pl.element() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we missing the positive case here? |
||
| assert not dy.Enum( | ||
| ["a", "b"], | ||
| sqlalchemy_use_enum=True, | ||
| sqlalchemy_enum_name="one", | ||
| ).matches( | ||
| dy.Enum( | ||
| ["a", "b"], | ||
| sqlalchemy_use_enum=True, | ||
| sqlalchemy_enum_name="two", | ||
| ), | ||
| expr, | ||
| ) | ||
|
|
||
|
|
||
| def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None: | ||
| column = dy.Enum( | ||
| ["a", "b"], | ||
| sqlalchemy_use_enum=True, | ||
| sqlalchemy_enum_name="my_enum", | ||
| ) | ||
| data = column.as_dict(pl.element()) | ||
| restored = dy.Enum.from_dict(data) | ||
| assert restored.sqlalchemy_use_enum is True | ||
| assert restored.sqlalchemy_enum_name == "my_enum" | ||
| assert restored.categories == ["a", "b"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| # Copyright (c) QuantCo 2025-2026 | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from enum import Enum | ||
|
|
||
| import pytest | ||
|
|
||
| import dataframely as dy | ||
|
|
@@ -171,3 +173,64 @@ def test_raise_for_object_column(dialect: Dialect) -> None: | |
| NotImplementedError, match="SQL column cannot have 'Object' type." | ||
| ): | ||
| dy.Object().sqlalchemy_dtype(dialect) | ||
|
|
||
|
|
||
| class _Status(str, Enum): | ||
| PENDING = "pending" | ||
| APPROVED = "approved" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("column", "dialect", "datatype"), | ||
| [ | ||
|
Comment on lines
+178
to
+185
|
||
| ( | ||
| dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), | ||
| PGDialect_psycopg2(), | ||
| "a", | ||
| ), | ||
| ( | ||
| dy.Enum( | ||
| ["foo", "bar"], | ||
| sqlalchemy_use_enum=True, | ||
| sqlalchemy_enum_name="my_status", | ||
| ), | ||
| PGDialect_psycopg2(), | ||
| "my_status", | ||
| ), | ||
| (dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a valid point, but covered by |
||
| ( | ||
| dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), | ||
| MSDialect_pyodbc(), | ||
| "VARCHAR(3)", | ||
| ), | ||
| ], | ||
| ) | ||
| def test_enum_sqlalchemy_native(column: Column, dialect: Dialect, datatype: str) -> None: | ||
| schema = create_schema("test", {"a": column}) | ||
| columns = schema.to_sqlalchemy_columns(dialect) | ||
| assert len(columns) == 1 | ||
| assert columns[0].type.compile(dialect) == datatype | ||
|
|
||
|
|
||
| def test_enum_sqlalchemy_native_python_enum_uses_member_values() -> None: | ||
| column = dy.Enum(_Status, sqlalchemy_use_enum=True) | ||
| schema = create_schema("test", {"a": column}) | ||
| sa_type = schema.to_sqlalchemy_columns(PGDialect_psycopg2())[0].type | ||
| assert list(sa_type.enums) == column.categories | ||
|
|
||
|
|
||
| def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None: | ||
| class TestSchema(dy.Schema): | ||
| status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) | ||
|
|
||
| column = TestSchema.columns()["status"] | ||
| assert column.sqlalchemy_dtype(PGDialect_psycopg2()).compile( | ||
| PGDialect_psycopg2() | ||
| ) == "status" | ||
|
|
||
|
|
||
| def test_enum_sqlalchemy_native_string_categories_requires_name_without_column( | ||
| ) -> None: | ||
| column = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) | ||
| with pytest.raises(ValueError, match="sqlalchemy_enum_name is required"): | ||
| column.sqlalchemy_dtype(PGDialect_psycopg2()) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a check here and fail loudly if a user sets
`sqlalchemy_enum_name` but not `sqlalchemy_use_enum`.