Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions dataframely/columns/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(
alias: str | None = None,
metadata: dict[str, Any] | None = None,
description: str | None = None,
sqlalchemy_use_enum: bool = False,
sqlalchemy_enum_name: str | None = None,
):
"""
Args:
Expand Down Expand Up @@ -68,6 +70,15 @@ def __init__(
names, the specified alias is the only valid name.
metadata: A dictionary of metadata to attach to the column.
description: A human-readable description of the column.
sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum`
in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of
``CHAR`` / ``VARCHAR``. Use this for PostgreSQL native enum types and
Alembic schema drift detection. Defaults to ``False`` (string columns).
sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type
when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a
Python :class:`enum.Enum` subclass, SQLAlchemy uses the enum class name
(lowercased). Otherwise the SQL column name from
:meth:`~dataframely.Schema.to_sqlalchemy_columns` is used.
"""
super().__init__(
nullable=nullable,
Expand All @@ -78,7 +89,11 @@ def __init__(
metadata=metadata,
description=description,
)
self.sqlalchemy_use_enum = sqlalchemy_use_enum
self.sqlalchemy_enum_name = sqlalchemy_enum_name
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a check here and fail loudly if a user sets sqlalchemy_enum_name but not sqlalchemy_use_enum

self._enum_class: type[enum.Enum] | None = None
if isclass(categories) and issubclass(categories, enum.Enum):
self._enum_class = categories
categories = (item.value for item in categories)
self.categories = list(categories)

Expand All @@ -91,12 +106,49 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool:
return False
return self.categories == dtype.categories.to_list()

def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column:
if self.sqlalchemy_use_enum:
return sa.Column(
name,
self._sqlalchemy_enum_type(dialect, column_name=name),
nullable=self.nullable,
primary_key=self.primary_key,
unique=self.unique,
autoincrement=False,
)
return super().sqlalchemy_column(name, dialect)
Comment on lines +111 to +116

def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
if self.sqlalchemy_use_enum:
column_name = self._name or None
return self._sqlalchemy_enum_type(dialect, column_name=column_name)
category_lengths = [len(c) for c in self.categories]
if all(length == category_lengths[0] for length in category_lengths):
return sa.CHAR(category_lengths[0])
return sa.String(max(category_lengths))

def _sqlalchemy_enum_type(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is taking me a surprisingly long time to mentally parse this method. Do you think you could add some more comments and possibly structure the code to make it easier to understand in case the reader is less familiar with sa.Enum?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, is it possible to factor this into a match / case structure to make more obvious what the orthogonal cases are? Totally fine if that means the resulting code takes a few lines more if it then is more readable

self, _dialect: sa.Dialect, *, column_name: str | None
) -> sa_TypeEngine:
length = max(len(c) for c in self.categories)
kwargs: dict[str, Any] = {"length": length}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concrete example: Do we need the kwargs to be a dict? I'd find it easier to read if we just passed length=length directly to sa.Enum without the dict deconstruction

name = self.sqlalchemy_enum_name
if self._enum_class is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would definitely be useful to add a comment here to say what self._enum_class is not None means conceptually (=="The dy.Enum was constructed from a python Enum")

if name is not None:
kwargs["name"] = name
return sa.Enum(self._enum_class, **kwargs)
Comment thread
AndreasAlbertQC marked this conversation as resolved.
if name is None:
name = column_name
if name is None:
raise ValueError(
"sqlalchemy_enum_name is required for dy.Enum with string categories "
"and sqlalchemy_use_enum=True when not building columns via "
"Schema.to_sqlalchemy_columns(). Alternatively, pass a Python "
"enum.Enum class as categories."
)
Comment on lines +142 to +146
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this already be checked in the constructor (early) rather than here (late)?

kwargs["name"] = name
return sa.Enum(*self.categories, **kwargs)

@property
def pyarrow_dtype(self) -> pa.DataType:
if len(self.categories) <= 2**8 - 1:
Expand Down
21 changes: 21 additions & 0 deletions docs/guides/features/sql-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,27 @@ the maximal length of the string is inferred from the regular expression if poss
maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys.
```

## Native SQL enums (optional)

By default, {class}`~dataframely.Enum` maps to fixed-length `CHAR` or `VARCHAR` columns so stored values remain plain strings. For PostgreSQL setups that use database-level `ENUM` types (for example with Alembic autogenerate), set `sqlalchemy_use_enum=True`:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of saying CHAR/VARCHAR, can we phrase this in sqlalchemy types? I think that would be clearer given that we generate sqlalchemy, not SQL directly


```python
from enum import StrEnum

import dataframely as dy


class Status(StrEnum):
PENDING = "pending"
APPROVED = "approved"


class Staged(dy.Schema):
status = dy.Enum(Status, sqlalchemy_use_enum=True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to expand this example by showing what sqlalchemy types would be generated here?

```

When `categories` is a Python `enum.Enum` subclass, SQLAlchemy uses the enum class name (lowercased) as the database enum type name. For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed. On dialects without native enums (such as Microsoft SQL Server), SQLAlchemy falls back to `VARCHAR` with a check constraint.

## Collections of multiple tables

If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection.
Expand Down
40 changes: 40 additions & 0 deletions tests/column_types/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,43 @@ def test_sequences_and_enums(
S = create_schema("test", {"x": dy.Enum(categories1)})
df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))})
S.validate(df)


def test_matches_sqlalchemy_use_enum() -> None:
expr = pl.element()
assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr)
assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches(
dy.Enum(["a", "b"]), expr
)
assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches(
dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr
)


def test_matches_sqlalchemy_enum_name() -> None:
expr = pl.element()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we missing the positive case here?

assert not dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="one",
).matches(
dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="two",
),
expr,
)


def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None:
column = dy.Enum(
["a", "b"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="my_enum",
)
data = column.as_dict(pl.element())
restored = dy.Enum.from_dict(data)
assert restored.sqlalchemy_use_enum is True
assert restored.sqlalchemy_enum_name == "my_enum"
assert restored.categories == ["a", "b"]
56 changes: 56 additions & 0 deletions tests/columns/test_sqlalchemy_columns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

from enum import Enum

import pytest

import dataframely as dy
Expand Down Expand Up @@ -171,3 +173,57 @@ def test_raise_for_object_column(dialect: Dialect) -> None:
NotImplementedError, match="SQL column cannot have 'Object' type."
):
dy.Object().sqlalchemy_dtype(dialect)


class _Status(str, Enum):
PENDING = "pending"
APPROVED = "approved"


@pytest.mark.parametrize(
("column", "dialect", "datatype"),
[
Comment on lines +178 to +185
(
dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True),
PGDialect_psycopg2(),
"a",
),
(
dy.Enum(
["foo", "bar"],
sqlalchemy_use_enum=True,
sqlalchemy_enum_name="my_status",
),
PGDialect_psycopg2(),
"my_status",
),
(dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a valid point, but covered by test_enum_sqlalchemy_native_python_enum_uses_member_values?

(
dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True),
MSDialect_pyodbc(),
"VARCHAR(3)",
),
],
)
def test_enum_sqlalchemy_native(column: Column, dialect: Dialect, datatype: str) -> None:
schema = create_schema("test", {"a": column})
columns = schema.to_sqlalchemy_columns(dialect)
assert len(columns) == 1
assert columns[0].type.compile(dialect) == datatype


def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None:
class TestSchema(dy.Schema):
status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True)

column = TestSchema.columns()["status"]
assert column.sqlalchemy_dtype(PGDialect_psycopg2()).compile(
PGDialect_psycopg2()
) == "status"


def test_enum_sqlalchemy_native_string_categories_requires_name_without_column(
) -> None:
column = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True)
with pytest.raises(ValueError, match="sqlalchemy_enum_name is required"):
column.sqlalchemy_dtype(PGDialect_psycopg2())