谢谢你留下时光匆匆
用Python进行SQL练习的简单方法

在进行数据分析/数据科学方向上求职时,SQL题目的练习是不可或缺的。对于自己所写的 SQL 代码,最好的验证方式是跑出代码的数据结果。常见的刷题平台(如牛客网,leetcode)都支持这样的功能。但不在这些刷题平台的SQL题目,想要去运行自己所写的答案,就需要搭建一个能运行 SQL 的环境(如本地MySQL),这对于非技术背景的同学可能会比较困难。我最近找到一个简单的方案,只要能运行Python,安装相关包后,就能运行 SQL、进行 SQL 练习。本文该方案进行介绍。

方案非常简单:使用Pandas载入数据,再利用pandasql包,即可使用SQL语句操作Pandas数据,执行SQL来得到返回数据。整个操作流程如下:

首先,利用 pip 安装 pandaspandasql Python 包

1
pip install pandas, pandasql

在 Python 环境中,将数据存储到 pandas DataFrame

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
df = pd.DataFrame([
    ('Adam', 18),
    ('Bod', 20),
    ('John', 35),
    ('David', 40),
    ('Frank', 47)
], columns=['name', 'age'])

# >>> df
#     name  age
# 0   Adam   18
# 1    Bod   20
# 2   John   35
# 3  David   40
# 4  Frank   47

用一个变量存下想要运行的SQL语句。需要注意的是,SQL语句中 From 后面跟随的表名为上面DataFrame对应的变量名(这里即为df

1
2
3
4
5
6
7
8
SQL = """
SELECT
    *
FROM
    df
WHERE
    age < 30
"""

用下面的语句运行SQL

1
2
3
4
5
6
7
8
9
from pandasql import sqldf
pysqldf = lambda q: sqldf(q, globals())

pysqldf(sql)

# 运行结果为:
#    name  age
# 0  Adam   18
# 1   Bod   20

需要注意的是,pandasql背后使用的是SQLite引擎执行运算,SQLite在一些函数的语法上与HiveSQL略有不同,比如,HiveSQL中的datediff函数在sqlite中需要用Cast(JulianDay(date1) - JulianDay(date2) AS Integer)。如果想完全使用HiveSQL语法,可以搭建一个本地Spark环境,使用SparkSQL运行结果,后续我可能再写一篇文章单独介绍。

Leetcode 题目中的题目会给出形如下方这样展示表数据的样例文本

1
2
3
4
5
6
7
8
+----+-------+
| id | name  |
+----+-------+
| 1  | Joe   |
| 2  | Henry |
| 3  | Sam   |
| 4  | Max   |
+----+-------+

为了练习方便,这里写了一个简单的 Python 函数,将这些样例文本转换为 pandas DataFrame

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def clean_leetcode_table(table_str: str):
    """
    将Leetcode中出现的表(如下格式),解析转换为pandas DataFrame

    +----+-------+
    | id | score |
    +----+-------+
    | 1  | 3.50  |
    | 2  | 3.65  |
    | 3  | 4.00  |
    | 4  | 3.85  |
    | 5  | 4.00  |
    | 6  | 3.65  |
    +----+-------+
    """

    table_str = table_str.strip()

    header = None
    data_lst = []
    for line in table_str.split('\\n'):
        if "+" in line: continue

        content = [token.strip() for token in line.strip(' |').split('|')]

        if header is None:
            header = content
            continue

        data_lst.append(content)
    
    # 进行数值格式转换
    result = pd.DataFrame(data_lst, columns=header)
    for series_name in pd.DataFrame(data_lst, columns=header):
        result[series_name] = pd.to_numeric(result[series_name], errors='ignore')
    return result

完整代码

以练习586. 订单最多的客户为例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pandas as pd
from pandasql import sqldf

pysqldf = lambda q: sqldf(q, globals())

def clean_leetcode_table(table_str: str):
    """
    将Leetcode中出现的表(如下格式),解析转换为pandas DataFrame

    +----+-------+
    | id | score |
    +----+-------+
    | 1  | 3.50  |
    | 2  | 3.65  |
    | 3  | 4.00  |
    | 4  | 3.85  |
    | 5  | 4.00  |
    | 6  | 3.65  |
    +----+-------+
    """

    table_str = table_str.strip()

    header = None
    data_lst = []
    for line in table_str.split('\\n'):
        if "+" in line: continue

        content = [token.strip() for token in line.strip(' |').split('|')]

        if header is None:
            header = content
            continue

        data_lst.append(content)
    
    # 进行数值格式转换
    result = pd.DataFrame(data_lst, columns=header)
    for series_name in pd.DataFrame(data_lst, columns=header):
        result[series_name] = pd.to_numeric(result[series_name], errors='ignore')
    return result

table_str = """
+--------------+-----------------+
| order_number | customer_number |
+--------------+-----------------+
| 1            | 1               |
| 2            | 2               |
| 3            | 3               |
| 4            | 3               |
+--------------+-----------------+
"""

Orders = clean_leetcode_table(table_str)

SQL = """
SELECT
    customer_number
FROM
    Orders
GROUP BY 1
ORDER BY COUNT(*) DESC
LIMIT 1
"""

pysqldf(sql)

# 返回结果:
# | customer_number |
# | --------------- |
# | 3               |