From 52e8cc05a58f69e08f951191a969b89039a5e52e Mon Sep 17 00:00:00 2001 From: Shuvayan Ghosh Dastidar Date: Sat, 13 Nov 2021 01:59:04 +0530 Subject: [PATCH] sql parser init --- .vscode/settings.json | 4 + gsql/backend/auth.py | 12 +- gsql/backend/sqlite_manager.py | 104 ++++++++++++++++ gsql/exceptions/sqlparser_exception.py | 6 + gsql/frontend/shell/driver.py | 27 ++++- gsql/frontend/shell/shell.py | 139 ++++++++++++++++----- gsql/frontend/sql_parser.py | 161 +++++++++++++++++++++++++ 7 files changed, 420 insertions(+), 33 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 gsql/backend/sqlite_manager.py create mode 100644 gsql/exceptions/sqlparser_exception.py create mode 100644 gsql/frontend/sql_parser.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..6f2cd04 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.pythonPath": "/Users/shuvayan/.pyenv/versions/3.9.7/bin/python", + "python.formatting.provider": "black" +} \ No newline at end of file diff --git a/gsql/backend/auth.py b/gsql/backend/auth.py index 0f8fdf1..c36c004 100644 --- a/gsql/backend/auth.py +++ b/gsql/backend/auth.py @@ -21,11 +21,21 @@ def save_token(self): with open(os.path.join(self.store_folder, "token.json"), "w") as token: token.write(self.creds.to_json()) + @staticmethod + def get_creds(): + """ + static method to get the credentials + """ + store_folder = os.path.join(os.path.expanduser("~"), ".gsql") + return Credentials.from_authorized_user_file( + os.path.join(store_folder, "token.json"), SCOPES + ) + def auth(self): # The file token.json stores the user's access and refresh tokens, and is # created automatically when the authorization flow completes for the first # time. - if os.path.exists(self.store_folder + "/token.json"): + if os.path.exists(os.path.join(self.store_folder, "token.json")): self.creds = Credentials.from_authorized_user_file( os.path.join(self.store_folder, "token.json"), SCOPES ) diff --git a/gsql/backend/sqlite_manager.py b/gsql/backend/sqlite_manager.py new file mode 100644 index 0000000..d0234bf --- /dev/null +++ b/gsql/backend/sqlite_manager.py @@ -0,0 +1,104 @@ +import sqlite3 +import os +import pandas as pd +from pandas.core.frame import DataFrame +import namegenerator + +class SQLiteManager: + ''' + a Manager class to help with SQL queries + ''' + + def __init__(self) -> None: + self.store_location = os.path.join(os.path.expanduser('~'), '.gsql', 'databases') + os.makedirs(self.store_location, exist_ok=True) + self.common_db_path = os.path.join( self.store_location , 'common.db') + + def _write_to_common(self, resultset): + + # convert the resultset dictionary into a set of tuples + insert_list = [] + + con = sqlite3.connect(self.common_db_path) + curr = con.cursor() + curr.execute("create table if not exists common (title varchar(50), id varchar(50) primary key, \ + nickname varchar(50))") + + table = pd.read_sql_query("select * from common", con) + for content in resultset: + if content['id'] not in table['id'].tolist(): + insert_list.append((content['title'], content['id'], namegenerator.gen())) + + + + curr.executemany("replace into common values(?, ?, ?);", insert_list) + + con.commit() + con.close() + + def _read_from_common(self) -> pd.DataFrame: + + con = sqlite3.connect(self.common_db_path) + table = pd.read_sql_query("select * from common", con) + con.close() + return table + + + def read_generic_select_statement(self, table_name, statement ) -> pd.DataFrame: + ''' + reads content from the database corresponding to a table name + and wraps it in a dataframe + + params + ------- + table_name : Name of the table (str) + statement : raw SQL statement to be passed on to SQLite + ''' + db_path = os.path.join(self.store_location, table_name + '.db') + con = sqlite3.connect(db_path) + table = pd.read_sql_query(statement, con) + con.close() + return table + + + def write_metadata(self, metadata) -> None: + + # preparing the metadata into tuple + insert_list = [] + db_id = metadata["spreadsheetId"] + db_name = metadata["title"] + + for sheet in metadata['sheets']: + properties = sheet['properties'] + insert_list.append((db_id, db_name, properties['sheetId'], properties['title'], + properties['gridProperties']['rowCount'], properties['gridProperties']['columnCount'])) + + + metadata_path = os.path.join( self.store_location , 'metadata_{}.db'.format(db_id)) + con = sqlite3.connect(metadata_path) + curr = con.cursor() + curr.execute("create table if not exists metadata (db_id varchar(50), db_name varchar(50), \ + sheet_id varchar(20) primary key, sheet_title varchar(50), row_count integer, col_count integer)") + curr.executemany("replace into metadata values(?, ?, ?, ?, ?, ?)" , insert_list) + con.commit() + con.close() + + def read_metadata(self, db_id) -> pd.DataFrame: + ''' + reads out the data corresponding to a particular ID + ''' + + metadata_path = os.path.join( self.store_location , 'metadata_{}.db'.format(db_id)) + con = sqlite3.connect(metadata_path) + table = pd.read_sql_query("select * from metadata", con) + con.close() + return table + + + + + + + + + diff --git a/gsql/exceptions/sqlparser_exception.py b/gsql/exceptions/sqlparser_exception.py new file mode 100644 index 0000000..cb73ab9 --- /dev/null +++ b/gsql/exceptions/sqlparser_exception.py @@ -0,0 +1,6 @@ + + +class SQLStatmentException(Exception): + def __init__(self, message="Invalid SQL statement") -> None: + super().__init__(message) + self.message = message \ No newline at end of file diff --git a/gsql/frontend/shell/driver.py b/gsql/frontend/shell/driver.py index 3abb7e3..00217fd 100644 --- a/gsql/frontend/shell/driver.py +++ b/gsql/frontend/shell/driver.py @@ -1,9 +1,12 @@ +from gsql.backend.sqlite_manager import SQLiteManager from gsql.frontend.constants import Commands from gsql.logging import logger from gsql.frontend.shell.shell import GSQLShell from rich import print from gsql.backend.auth import Auth from gsql.console import console +from gsql.backend.api_handler import ApiHandler +import sys class GSQLDriver: @@ -14,11 +17,13 @@ class GSQLDriver: def __init__(self, action: str) -> None: self.action = action logger.debug("GSQL called with action :{}".format(self.action)) - self.shell_instance = GSQLShell() + self.shell_instance = None + self.auth = Auth() + self.api = None + self.sqlite_manager = SQLiteManager() def authenticate(self): - auth = Auth() - err = auth.auth() + err = self.auth.auth() if err: logger.error("Authentication failed: {}".format(err)) console.print("[red]Authentication failed!!!") @@ -43,6 +48,22 @@ def show_help(self): print("help") def start_shell(self): + # check if not authenticated force the user to authenticate + err = self.auth.auth() + if err: + logger.error("Authentication failed: {}".format(err)) + console.print("[red]Authentication failed!!!") + sys.exit(0) + + # fetch all databases before starting gsql shell + with console.status( + "Preparing and personalizing GSQL for you ....", spinner="bouncingBall" + ): + if self.api is None: + self.api = ApiHandler(Auth.get_creds()) + result = self.api.getAllSpreadsheetInfo() + self.sqlite_manager._write_to_common(result) + self.shell_instance = GSQLShell() self.shell_instance.cmdloop() def error_(self): diff --git a/gsql/frontend/shell/shell.py b/gsql/frontend/shell/shell.py index 56bbeec..8e31afc 100644 --- a/gsql/frontend/shell/shell.py +++ b/gsql/frontend/shell/shell.py @@ -1,6 +1,12 @@ +from gsql.exceptions.sqlparser_exception import SQLStatmentException +from gsql.backend.auth import Auth +from gsql.backend.api_handler import ApiHandler +from gsql.backend.sqlite_manager import SQLiteManager +from gsql.frontend.sql_parser import SQLParser import os import cmd + try: import readline except ImportError: @@ -8,7 +14,6 @@ from gsql.console import console from rich.table import Table -import time gsql_text = """ @@ -36,17 +41,30 @@ def __init__(self) -> None: os.path.expanduser("~"), ".gsql", "gsql_history.txt" ) self.histfile_size = 1000 - # dummy data - self.sheets = [ - {"name": "Fun Sheet", "id": "1WooAUEpz7ECEK2M7YIS3WzNK2c"}, - {"name": "Another Fun Sheet", "id": "qjdq286382bhd27872gr44"}, - ] + + self.sql_parser = SQLParser() + self.sqlite_manager = SQLiteManager() + self.api = ApiHandler(Auth.get_creds()) + self._get_sheets() + + def _get_sheets(self): + df = self.sqlite_manager._read_from_common() + self.sheets = df def preloop(self): if readline and os.path.exists(self.history_file): readline.read_history_file(self.history_file) def default(self, line): + """ + Should be a either a valid SQL statement or error + """ + try: + self.sql_parser.parse_statement(line) + except SQLStatmentException as e: + console.print("[red]{}[/]".format(e.message)) + + def _show_error(self, line): """ Prints out an error message to the console """ @@ -60,27 +78,52 @@ def do_show(self, args): in your account. Usage : show databases - Prints out name and id of the spreadsheets + : show tables + - Prints out the available tables in a databases ( sheets in a spreadsheet ) """ arg_tokens = args.split() - if len(arg_tokens) != 1 or arg_tokens[0].lower() != "databases": - self.default("show " + args) + if len(arg_tokens) != 1 or arg_tokens[0].lower() not in ["databases", "tables"]: + self._show_error("show " + args) return - # TODO get the details from API + if arg_tokens[0].lower() == "tables": + if self.sql_parser.database is None: + console.print( + "[red]error: gsql: Not connected to database, please first connect to a database[/]" + ) + return with console.status("Getting your data ....", spinner="bouncingBall"): - # call API synchronous call - time.sleep(3) - - table = Table(title="Your databases") - table.add_column("Serial", justify="right", no_wrap=True) - table.add_column("Name", style="green") - table.add_column("ID", justify="right") - # limiting display upto first 20 sheets - to_be_shown = self.sheets[:20] - for i, item in enumerate(to_be_shown): - table.add_row(str(i + 1), item["name"], item["id"]) + if arg_tokens[0].lower() == "databases": + table = Table(title="Your databases") + table.add_column("Serial", justify="right", no_wrap=True) + table.add_column("Title", style="green") + table.add_column("ID", justify="right") + table.add_column("Nickname", style="green") + # limiting display upto first 20 sheets + to_be_shown = self.sheets[:20] + for i, item in to_be_shown.iterrows(): + table.add_row( + str(i + 1), item["title"], item["id"], item["nickname"] + ) + else: + table = Table(title="Tables") + table.add_column("Serial", justify="right", no_wrap=True) + table.add_column("Name", style="green") + table.add_column("ID", justify="right") + table.add_column("Row Count") + table.add_column("Column Count") + + data = self.sqlite_manager.read_metadata(self.sql_parser.database) + for i, item in data.iterrows(): + table.add_row( + str(i + 1), + item["sheet_title"], + str(item["sheet_id"]), + str(item["row_count"]), + str(item["col_count"]), + ) console.print(table) @@ -90,27 +133,60 @@ def do_connect(self, args): ------------ Command to connect to the database (google sheet) whose id user provides - Usage : connect + Usage : connect - Connects to the database for further operations on it """ arg_tokens = args.split() if len(arg_tokens) != 1: - self.default("connect " + args) + self._show_error("connect " + args) return - sheet_id = arg_tokens[0] + sheet_arg = arg_tokens[0] with console.status("Attempting to connect ....", spinner="bouncingBall"): - # TODO call API synchronous call - if sheet_id not in [item["id"] for item in self.sheets]: - console.print(f"[red]Database with id: {sheet_id} not found[/]") + is_sheet_name = sheet_arg in self.sheets["title"].tolist() + is_sheet_id = sheet_arg in self.sheets["id"].tolist() + is_sheet_nickname = sheet_arg in self.sheets["nickname"].tolist() + if not is_sheet_name and not is_sheet_id and not is_sheet_nickname: + console.print(f"[red]Database with arg : {sheet_arg} not found[/]") return - time.sleep(3) - sheet_name = list(filter(lambda x: x["id"] == sheet_id, self.sheets))[0]["name"] - self.prompt = "GSQL (" + sheet_name.replace(" ", "")[:10] + ") > " + if is_sheet_name: + sheet_name = sheet_arg + sheet_id = self.sheets.loc[ + self.sheets["title"] == sheet_arg, "id" + ].iloc[0] + elif is_sheet_id: + sheet_name = self.sheets.loc[ + self.sheets["id"] == sheet_arg, "title" + ].iloc[0] + sheet_id = sheet_arg + else: + sheet_name = self.sheets.loc[ + self.sheets["nickname"] == sheet_arg, "title" + ].iloc[0] + sheet_id = self.sheets.loc[ + self.sheets["nickname"] == sheet_arg, "id" + ].iloc[0] + + # get the data and store it in sqlite + metadata = self.api.getSpreadsheetInfo(sheet_id) + self.sqlite_manager.write_metadata(metadata) + self.prompt = "GSQL (" + sheet_name.replace(" ", "")[:10] + ") > " + self.sql_parser.database = sheet_id console.print(f"[green]Connected to {str(sheet_id)}[/]") + def emptyline(self): + """Called when an empty line is entered in response to the prompt. + + If this method is not overridden, it repeats the last nonempty + command entered. + + """ + if self.lastcmd: + self.lastcmd = "" + return self.onecmd("\n") + def do_disconnect(self, args): """ Disconnect @@ -123,6 +199,7 @@ def do_disconnect(self, args): # TODO make api call and remove from cache self.prompt = "GSQL > " + del self.sql_parser.database console.print("[green]Disconnected successfully[/]") def do_clear(self, args): @@ -138,6 +215,10 @@ def precmd(self, line): """ Converts the current line to lowercase """ + + # check empty command + if line == "": + return line tokens = str(line).split() if tokens[0].lower() == "connect": line = "connect " + " ".join(token for token in tokens[1:]) diff --git a/gsql/frontend/sql_parser.py b/gsql/frontend/sql_parser.py new file mode 100644 index 0000000..b845882 --- /dev/null +++ b/gsql/frontend/sql_parser.py @@ -0,0 +1,161 @@ + +from gsql.exceptions.sqlparser_exception import SQLStatmentException +import sqlparse +from sqlparse.tokens import Keyword, DML, DDL, Punctuation +import sqlvalidator + + + +class SQLTokens: + + SELECT = 'SELECT' + UPDATE = 'UPDATE' + DELETE = 'DELETE' + INSERT = 'INSERT' + ALTER = 'ALTER' + CREATE = 'CREATE' + ERROR = 'ERROR' + + +class SQLDTO: + + def __init__(self, raw_statement, table_name='sample table', + type_statement: SQLTokens = SQLTokens.SELECT, + affected_columns = [],updated_values = {}, deleted_columns = {}, + filter_function = None + ) -> None: + ''' + raw_statement : The raw SQL statement passed to the GSQL shell + table_name : The table corresponding to the query + type_statement : Type of the SQL query determined by SQLTokens + affected_columns : The name of the columns affected by the query + will be [] in case of select queries. + updated_values : a mapping of table names to the updated values + will be empty in all queries except update and insert + filter_function : A lambda expression to filter rows based on parameters will be None + for SQL queries + ''' + pass + + +class SQLParser: + + def __init__(self ) -> None: + self.db_id = None + self.db_name = None + + @property + def database(self): + return self.db_id + + @database.setter + def database(self, db_id): + self.db_id = db_id + + @database.deleter + def database(self): + self.db_id = None + + def _validate_statement(self, raw_statement): + parsed = sqlvalidator.parse(raw_statement) + if not parsed.is_valid(): + raise SQLStatmentException("gsql: error: SQL SELECT Query not valid : " + str(parsed.errors)) + + + def classify_statement(self, parsed): + first_token = parsed.tokens[0] + if first_token.ttype == DML and first_token.value.upper() == SQLTokens.SELECT: + return SQLTokens.SELECT + elif first_token.ttype == DML and first_token.value.upper() == SQLTokens.UPDATE: + return SQLTokens.UPDATE + elif first_token.ttype == DML and first_token.value.upper() == SQLTokens.INSERT: + return SQLTokens.INSERT + elif first_token.ttype == DML and first_token.value.upper() == SQLTokens.DELETE: + return SQLTokens.DELETE + elif first_token.ttype == DDL and first_token.value.upper() == SQLTokens.ALTER: + return SQLTokens.ALTER + elif first_token.ttype == DDL and first_token.value.upper() == SQLTokens.CREATE: + return SQLTokens.CREATE + else: + return SQLTokens.ERROR + + + def handle_select_statement(self, statement): + # self._validate_statement(statement) + # TODO connect to sqlite to get the result from sqlite + print('SELECT' , statement) + + def handle_update_statement(self, statement): + ''' + prepare a DTO based on parsed update statement + The DTO is to be passed on for making further API calls + ''' + print('UPDATE' , statement) + + def handle_insert_statement(self, statement): + ''' + prepare a DTO based on parsed insert statement + The DTO is to be passed on for making further API calls + ''' + print('INSERT' , statement) + + def handle_delete_statement(self, statement): + ''' + prepare a DTO based on parsed delete statement + The DTO is to be passed on for making further API calls + ''' + print('DELETE' , statement) + + def handle_alter_statement(self, statement): + ''' + prepare a DTO based on parsed alter statement + The DTO is to be passed on for making further API calls + ''' + print('ALTER' , statement) + + def handle_create_statement(self, statement): + ''' + prepare a DTO based on parsed create statement + The DTO is to be passed on for making further API calls + ''' + print('CREATE' , statement) + + + def handle_statement(self, statement): + class_token = self.classify_statement(statement) + if class_token == SQLTokens.ERROR: + raise SQLStatmentException("gsql: error: invalid statement") + if class_token == SQLTokens.SELECT: + self.handle_select_statement(statement) + if class_token == SQLTokens.UDPATE: + self.handle_update_statement(statement) + if class_token == SQLTokens.INSERT: + self.handle_insert_statement(statement) + if class_token == SQLTokens.DELETE: + self.handle_delete_statement(statement) + if class_token == SQLTokens.ALTER: + self.handle_alter_statement(statement) + if class_token == SQLTokens.DELETE: + self.handle__statement(statement) + + + + def parse_statement(self, raw_statement: str): + ''' + raw_statement : raw SQL statement + ''' + # get rid of the white spaces + raw_statement = raw_statement.strip() + parsed = sqlparse.parse(raw_statement) + if self.db_id is None: + raise SQLStatmentException("Please connect to a database before continuing, \ + type help connect for more info") + + for statement in parsed: + tokens = statement.tokens + if tokens[-1].ttype != Punctuation or tokens[-1].value != ';': + raise SQLStatmentException("Expected ; at the end of a SQL statement") + self.handle_statement(statement) + + +