Skip to content

Subquery fix #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/dataneuron/core/nlp_helpers/cte_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def extract_main_query(parsed):
def filter_cte(cte_part, filter_function, client_id):
filtered_ctes = []

is_recursive = False

for token in cte_part.tokens:
if token.ttype is Keyword and token.value.upper() == 'RECURSIVE':
is_recursive = True

def process_cte(token):
if isinstance(token, sqlparse.sql.Identifier):
cte_name = token.get_name()
Expand All @@ -57,7 +63,7 @@ def process_cte(token):
# Remove outer parentheses
inner_query_str = str(inner_query)[1:-1]
filtered_inner_query = filter_function(
sqlparse.parse(inner_query_str)[0], client_id)
sqlparse.parse(inner_query_str)[0], client_id, cte_name)
filtered_ctes.append(f"{cte_name} AS ({filtered_inner_query})")

for token in cte_part.tokens:
Expand All @@ -68,7 +74,10 @@ def process_cte(token):
process_cte(token)

if filtered_ctes:
filtered_cte_str = "WITH " + ",\n".join(filtered_ctes)
if is_recursive:
filtered_cte_str = "WITH RECURSIVE " + ",\n".join(filtered_ctes)
else:
filtered_cte_str = "WITH " + ",\n".join(filtered_ctes)
else:
filtered_cte_str = ""
return filtered_cte_str
Expand Down
109 changes: 79 additions & 30 deletions src/dataneuron/core/sql_query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@ def apply_client_filter(self, sql_query: str, client_id: int) -> str:

return self._cleanup_whitespace(str(result))

def _apply_filter_recursive(self, parsed, client_id):
def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None):
if self._is_cte_query(parsed):
return handle_cte_query(parsed, self._apply_filter_recursive, client_id)

if isinstance(parsed, Token) and parsed.ttype is DML:
return self._apply_filter_to_single_query(str(parsed), client_id)
elif self._contains_set_operation(parsed):
return self._handle_set_operation(parsed, client_id)
elif self._contains_subquery(parsed):
return self._handle_subquery(parsed, client_id)
else:
filtered_query = self._apply_filter_to_single_query(
str(parsed), client_id)
return self._handle_where_subqueries(sqlparse.parse(filtered_query)[0], client_id)
for token in parsed.tokens:
if isinstance(token, Token) and token.ttype is DML:
if self._contains_set_operation(parsed):
return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id)
elif self._contains_subquery(parsed):
return self._handle_subquery(parsed, client_id)
else:
return self._apply_filter_to_single_query(str(parsed), client_id)

def _contains_set_operation(self, parsed):
set_operations = ('UNION', 'INTERSECT', 'EXCEPT')
Expand Down Expand Up @@ -227,7 +225,7 @@ def _inject_where_clause(self, parsed, where_clause):

return str(parsed)

def _handle_set_operation(self, parsed, client_id):
def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None):
print("Handling set operation")
# Split the query into individual SELECT statements
statements = []
Expand All @@ -253,9 +251,14 @@ def _handle_set_operation(self, parsed, client_id):
# Apply the filter to each SELECT statement
filtered_statements = []
for stmt in statements:
filtered_stmt = self._apply_filter_to_single_query(stmt, client_id)
filtered_statements.append(filtered_stmt)
print(f"Filtered statement: {filtered_stmt}")
if is_cte:
filtered_stmt = self._apply_filter_to_single_CTE_query(stmt, client_id, cte_name)
filtered_statements.append(filtered_stmt)
print(f"Filtered statement: {filtered_stmt}")
else:
filtered_stmt = self._apply_filter_to_single_query(stmt, client_id)
filtered_statements.append(filtered_stmt)
print(f"Filtered statement: {filtered_stmt}")

# Reconstruct the query
result = f" {set_operation} ".join(filtered_statements)
Expand Down Expand Up @@ -363,45 +366,51 @@ def _cleanup_whitespace(self, query: str) -> str:
def _handle_subquery(self, parsed, client_id):
result = []
tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed]
mainquery = []

for token in tokens:
if isinstance(token, Identifier) and token.has_alias():
if isinstance(token.tokens[0], Parenthesis):
mainquery.append(" PLACEHOLDER ")
subquery = token.tokens[0].tokens[1:-1]
subquery_str = ' '.join(str(t) for t in subquery)
filtered_subquery = self._apply_filter_recursive(
sqlparse.parse(subquery_str)[0], client_id)
alias = token.get_alias()
result.append(f"({filtered_subquery}) AS {alias}")
AS_keyword = next((t for t in token.tokens if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'AS'), None) # Checks for existence of 'AS' keyword

if AS_keyword:
result.append(f"({filtered_subquery}) AS {alias}")
else:
result.append(f"({filtered_subquery}) {alias}")
else:
result.append(str(token))
mainquery.append(str(token))

elif isinstance(token, Parenthesis):
mainquery.append(" PLACEHOLDER ")
subquery = token.tokens[1:-1]
subquery_str = ' '.join(str(t) for t in subquery)
filtered_subquery = self._apply_filter_recursive(
sqlparse.parse(subquery_str)[0], client_id)
result.append(f"({filtered_subquery})")
elif isinstance(token, Where):

elif isinstance(token, Where) and 'IN' in str(parsed):
try:
filtered_where = self._handle_where_subqueries(
token, client_id)
result.append(str(filtered_where))
except Exception as e:
result.append(str(token))
else:
# Preserve whitespace tokens
if token.is_whitespace:
result.append(str(token))
else:
# Add space before and after non-whitespace tokens, except for punctuation
if result and not result[-1].endswith(' ') and not str(token).startswith((')', ',', '.')):
result.append(' ')
result.append(str(token))
if not str(token).endswith(('(', ',')):
result.append(' ')
mainquery.append(str(token))

final_result = ''.join(result).strip()
return final_result
mainquery = ''.join(mainquery).strip()
if ' IN ' in str(parsed):
return f"{mainquery} {result[0]}"
else:
filtered_mainquery = self._apply_filter_to_single_query(mainquery, client_id)
query = filtered_mainquery.replace("PLACEHOLDER", result[0])
return query

def _handle_where_subqueries(self, where_clause, client_id):
if self._is_cte_query(where_clause):
Expand Down Expand Up @@ -508,3 +517,43 @@ def _extract_main_table(self, where_clause):
if isinstance(token, Identifier):
return token.get_real_name()
return None

def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str:
parts = sql_query.split(' GROUP BY ')
main_query = parts[0]

group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else ""
parsed = sqlparse.parse(main_query)[0]
tables_info = self._extract_tables_info(parsed)

filters = []
_table_ = []

for table_info in tables_info:
if table_info['name'] != cte_name:
table_dict = {
"name": table_info['name'],
"alias": table_info['alias'],
"schema": table_info['schema']
}
_table_.append(table_dict)

matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema'])

if matching_table:
client_id_column = self.client_tables[matching_table]
table_reference = _table_[0]['alias'] or _table_[0]['name']

filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}')

if filters:
where_clause = " AND ".join(filters)
if 'WHERE' in main_query.upper():
where_parts = main_query.split('WHERE', 1)
result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}"
else:
result = f"{main_query} WHERE {where_clause}"
else:
result = main_query

return result + group_by
121 changes: 61 additions & 60 deletions tests/core/test_sql_query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def test_subquery_in_from(self):
expected = 'SELECT * FROM (SELECT * FROM orders WHERE "orders"."user_id" = 1) AS subq'
self.assertEqual(self.filter.apply_client_filter(query, 1), expected)

# def test_subquery_in_join(self):
# query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id'
# expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1'
# self.assertEqual(self.filter.apply_client_filter(query, 1), expected)
def test_subquery_in_join(self):
query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id'
expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1'
self.assertEqual(self.filter.apply_client_filter(query, 1), expected)

def test_nested_subqueries(self):
query = 'SELECT * FROM (SELECT * FROM (SELECT * FROM orders) AS inner_subq) AS outer_subq'
Expand All @@ -123,7 +123,8 @@ def setUp(self):
'products': 'company_id',
'inventory.items': 'organization_id',
'items': 'organization_id',
'customers': 'customer_id'
'customers': 'customer_id',
'categories': 'company_id'
}
self.filter = SQLQueryFilter(
self.client_tables, schemas=['main', 'inventory'])
Expand Down Expand Up @@ -217,61 +218,61 @@ def test_multiple_ctes(self):
self.assertSQLEqual(
self.filter.apply_client_filter(query, 1), expected)

# def test_cte_with_subquery(self):
# query = '''
# WITH top_products AS (
# SELECT p.id, p.name, SUM(o.quantity) as total_sold
# FROM products p
# JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id
# GROUP BY p.id, p.name
# ORDER BY total_sold DESC
# LIMIT 10
# )
# SELECT * FROM top_products
# '''
# expected = '''
# WITH top_products AS (
# SELECT p.id, p.name, SUM(o.quantity) as total_sold
# FROM products p
# JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id
# WHERE "p"."company_id" = 1
# GROUP BY p.id, p.name
# ORDER BY total_sold DESC
# LIMIT 10
# )
# SELECT * FROM top_products
# '''
# self.assertSQLEqual(
# self.filter.apply_client_filter(query, 1), expected)

# def test_recursive_cte(self):
# query = '''
# WITH RECURSIVE category_tree AS (
# SELECT id, name, parent_id, 0 AS level
# FROM categories
# WHERE parent_id IS NULL
# UNION ALL
# SELECT c.id, c.name, c.parent_id, ct.level + 1
# FROM categories c
# JOIN category_tree ct ON c.parent_id = ct.id
# )
# SELECT * FROM category_tree
# '''
# expected = '''
# WITH RECURSIVE category_tree AS (
# SELECT id, name, parent_id, 0 AS level
# FROM categories
# WHERE parent_id IS NULL AND "categories"."company_id" = 1
# UNION ALL
# SELECT c.id, c.name, c.parent_id, ct.level + 1
# FROM categories c
# JOIN category_tree ct ON c.parent_id = ct.id
# WHERE "c"."company_id" = 1
# )
# SELECT * FROM category_tree
# '''
# self.assertSQLEqual(
# self.filter.apply_client_filter(query, 1), expected)
def test_cte_with_subquery(self):
query = '''
WITH top_products AS (
SELECT p.id, p.name, SUM(o.quantity) as total_sold
FROM products p
JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id
GROUP BY p.id, p.name
ORDER BY total_sold DESC
LIMIT 10
)
SELECT * FROM top_products
'''
expected = '''
WITH top_products AS (
SELECT p.id, p.name, SUM(o.quantity) as total_sold
FROM products p
JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id
WHERE "p"."company_id" = 1
GROUP BY p.id, p.name
ORDER BY total_sold DESC
LIMIT 10
)
SELECT * FROM top_products
'''
self.assertSQLEqual(
self.filter.apply_client_filter(query, 1), expected)

def test_recursive_cte(self):
query = '''
WITH RECURSIVE category_tree AS (
SELECT id, name, parent_id, 0 AS level
FROM categories
WHERE parent_id IS NULL
UNION ALL
SELECT c.id, c.name, c.parent_id, ct.level + 1
FROM categories c
JOIN category_tree ct ON c.parent_id = ct.id
)
SELECT * FROM category_tree
'''
expected = '''
WITH RECURSIVE category_tree AS (
SELECT id, name, parent_id, 0 AS level
FROM categories
WHERE parent_id IS NULL AND "categories"."company_id" = 1
UNION ALL
SELECT c.id, c.name, c.parent_id, ct.level + 1
FROM categories c
JOIN category_tree ct ON c.parent_id = ct.id
WHERE "c"."company_id" = 1
)
SELECT * FROM category_tree
'''
self.assertSQLEqual(
self.filter.apply_client_filter(query, 1), expected)


if __name__ == '__main__':
Expand Down