diff --git a/sqlmodel/main.py b/sqlmodel/main.py index c551afea36..fed132dbcc 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -613,11 +613,11 @@ def get_config(name: str) -> Any: # TODO: remove this in the future new_cls.model_config["read_with_orm_mode"] = True # ty: ignore[invalid-key] - config_registry = get_config("registry") + config_registry = kwargs.get("registry", Undefined) if config_registry is not Undefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config["registry"] = config_table + new_cls.model_config["registry"] = config_registry setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 setattr(new_cls, "__abstract__", True) # noqa: B010 diff --git a/tests/test_registry_kwarg.py b/tests/test_registry_kwarg.py new file mode 100644 index 0000000000..c121b009ca --- /dev/null +++ b/tests/test_registry_kwarg.py @@ -0,0 +1,30 @@ +"""Tests for SQLModelMetaclass.__new__ registry kwarg handling.""" + +from sqlalchemy.orm import registry +from sqlmodel import SQLModel + + +def test_custom_registry_base_stores_registry_in_model_config() -> None: + """model_config['registry'] must hold the registry object passed as kwarg.""" + custom_registry = registry() + + class MyBase(SQLModel, registry=custom_registry): + pass + + stored = MyBase.model_config.get("registry") + assert stored is custom_registry, ( + f"model_config['registry'] should be the custom registry, got {stored!r}" + ) + + +def test_custom_registry_base_sets_sa_registry() -> None: + """_sa_registry must reference the registry object passed as kwarg.""" + custom_registry = registry() + + class MyBase2(SQLModel, registry=custom_registry): + pass + + sa_registry = getattr(MyBase2, "_sa_registry", None) + assert sa_registry is custom_registry, ( + f"_sa_registry should be the custom registry, got {sa_registry!r}" + )