Writing unit tests for SQL queries
We had a bug at work in the code that builds an SQL query, which no tests covered. I'll go through what went wrong and how we broke the code into testable pieces.
Following my recent blog post, we were the subject of a cosmic joke of sorts when we experienced a bug today. Digging into the code, I realized that the bug was introduced recently, when we refactored the part of the codebase that builds an SQL query, a part that was not covered by any tests. I’ve linked the article below, but this post is self-contained.
The context
To load data into our data platform, we use reusable clients that streamline the process of loading a given table, optionally using partitioning to minimize how much data we load in each run. We have a class for each integrated database technology, and they all implement an interface requiring a method similar to the following:
def fetch(
    self,
    table: DatabaseTable,
    partition: Partition | None = None,
) -> pd.DataFrame:
    ...
This was difficult to write tests for: we don’t control the state. However, as I alluded to in my prior post, one has to ask: what exactly do we want to test?
The bug
Before the regression, the generated query was:
SELECT
    *
FROM
    dbo.friends
WHERE
    date_of_birth BETWEEN '19900101' AND '19900101'
After the regression, it was:
SELECT
    *
FROM
    dbo.friends
WHERE
    date_of_birth BETWEEN 19900101 AND 19900101
In hindsight, I am surprised you can query a string column with integers - you can, but you get an empty result.
The pull request that caused the bug moved the date formatting[1] into a separate method called format_date.
No one reviewing the code spotted that the apostrophes around the dates had been accidentally dropped. It’s a small thing, and humans are flawed (at least I am). We noticed it almost instantly, and naturally it was easily fixed. Nevertheless, I feared poetic irony if we didn’t cover this with unit tests. But how? The fetch method wasn’t easily testable.
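To make the regression concrete, here is a minimal sketch of how a dropped pair of quotes changes the generated SQL. The names where_clause_buggy and where_clause_fixed are hypothetical and only illustrate the mistake; they are not our actual code.

```python
from datetime import datetime


def format_date(date: datetime) -> str:
    # Same format the source system uses: YYYYMMDD.
    return date.strftime("%Y%m%d")


def where_clause_buggy(column: str, start: datetime, end: datetime) -> str:
    # Hypothetical reconstruction of the regression: without quotes,
    # the database sees integer literals instead of strings.
    return f"WHERE {column} BETWEEN {format_date(start)} AND {format_date(end)}"


def where_clause_fixed(column: str, start: datetime, end: datetime) -> str:
    # With quotes, the values are string literals again.
    return f"WHERE {column} BETWEEN '{format_date(start)}' AND '{format_date(end)}'"
```

Both functions produce syntactically valid SQL, which is part of why the difference is so easy to miss in review.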
Refactoring for testability
So what we learned from my previous post was that we don’t need to test the entire method, simply how we build the SQL query.
I looked at the implementation of fetch, which (slightly rewritten) looked like this:
def fetch(
    self,
    table: DatabaseTable,
    partition: Partition | None = None,
) -> pd.DataFrame:
    query = f"SELECT * FROM {table.schema_name}.{table.table_name}"
    if partition:
        query += self._create_where_clause(partition)
    return self._sql(query)
However, not knowing which parts required state and a database connection, I started at the lowest level and worked my way up. The new format_date method never accessed the state of the underlying class; it was a pure function, so I moved it out of the class into a standalone function rather than keeping it as a private method.
Before:
def _format_date(
    self,
    date: datetime,
) -> str:
    return datetime.strftime(date, "%Y%m%d")
After:
def format_date(
    date: datetime,
) -> str:
    return datetime.strftime(date, "%Y%m%d")
Suddenly, that was trivial to test:
def test__format_date():
    assert format_date(datetime(2023, 1, 1)) == "20230101"
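A pure function also makes it cheap to pin down a few edge cases. A small sketch, with format_date repeated so the snippet runs on its own; the chosen dates and test names are just illustrative:

```python
from datetime import datetime


def format_date(date: datetime) -> str:
    return datetime.strftime(date, "%Y%m%d")


def test__format_date_pads_month_and_day():
    # Single-digit months and days must be zero-padded to keep YYYYMMDD.
    assert format_date(datetime(2023, 9, 1)) == "20230901"


def test__format_date_ignores_time_of_day():
    # Only the date part should end up in the query.
    assert format_date(datetime(2023, 12, 31, 23, 59, 59)) == "20231231"
```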
Great! So far so good. Moving my way up, I looked at the _create_where_clause method, which was essentially the source of the bug. However, it used the client to look up the datatype of the column and work out exactly how to query the field, so it was less trivial:
def _create_where_clause(
    self,
    partition: Partition | None = None,
) -> str:
    # Something that required state
    ...
    return f"""
        WHERE
            {start} <= {partition.column_name}
            AND {partition.column_name} < {end}
    """
With the code simplified like this, the point becomes quite obvious: the method was doing multiple things. It used the state to define start and end, and then it built the query.
By rewriting the above code into the following:
def _create_where_clause(
    self,
    partition: Partition | None = None,
) -> str:
    # Something that required state (e.g. looking up column_type)
    ...
    return build_where_clause(partition, column_type)
I suddenly had something pure that could be tested:
def test_build_where_clause():
    output = build_where_clause(
        Partition(
            column_name="timestamp",
            start=datetime(2023, 9, 1),
            end=datetime(2023, 9, 10),
        ),
        ColumnType.VARCHAR2,
    )
    assert output == "WHERE timestamp BETWEEN '20230901' AND '20230910'"
The above test failed because of the aforementioned bug; however, we now had something concrete to fix, and a way to validate whether a fix worked.
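I haven't shown build_where_clause itself, so here is a minimal sketch of what such a pure function could look like. The shapes of Partition and ColumnType are assumptions inferred from the test above, and only the VARCHAR2 case is modelled; the real implementation handles more than this.

```python
from dataclasses import dataclass
from datetime import datetime
from enum import Enum


class ColumnType(Enum):
    # Assumed shape; only the member used in the test is modelled here.
    VARCHAR2 = "VARCHAR2"


@dataclass(frozen=True)
class Partition:
    # Assumed fields, taken from how the test constructs a Partition.
    column_name: str
    start: datetime
    end: datetime


def format_date(date: datetime) -> str:
    return date.strftime("%Y%m%d")


def build_where_clause(partition: Partition, column_type: ColumnType) -> str:
    if column_type is ColumnType.VARCHAR2:
        # String columns need quoted literals; dropping these quotes was the bug.
        start = f"'{format_date(partition.start)}'"
        end = f"'{format_date(partition.end)}'"
    else:
        # Other column types are deliberately left out of this sketch.
        raise NotImplementedError(f"Unsupported column type: {column_type}")
    return f"WHERE {partition.column_name} BETWEEN {start} AND {end}"
```

Because the function takes everything it needs as arguments, the test needs no database connection and no mocks.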
Eureka.
[1] The source system has this specific date as a VARCHAR, formatted as YYYYMMDD.