“文本转SQL通俗的来说,就是用户输入自然语言,通过LLM大语言模型将自然语言结合表结构生成SQL的过程。”
上一篇文章我们介绍了Milvus向量数据库,它是用来存储向量数据的,我们将各种非结构化的文档转为向量,存储在向量数据库。但是RAG系统不只有非结构化的数据,也有结构化的数据,比如说存储在MySQL或者其他关系型数据库中的数据。如何在RAG系统里面对这类数据进行检索查询呢,这就是本文要分享的内容:文本转SQL``Text2SQL。
01 — 文本转SQL
文本转SQL通俗的来说,就是用户输入自然语言,通过LLM大语言模型将自然语言结合表结构生成SQL的过程。实现思路如下:
流程图
代码示例
import os from dotenv import load_dotenv load_dotenv() import pymysql import json # ====================== MYSQL的配置 ====================== MYSQL_CONFIG = { "host": "localhost", "port": 3306, "user": "root", "password": "root", "database": "finance_enterprise_db", "charset": "utf8mb4" } # 配置你的大模型API MIMO_API_KEY = os.getenv("M_PROXY_AI_API_KEY") MIMO_BASE_URL = os.getenv("M_PROXY_AI_BASE_URL") MIMO_MODEL = "mimo-v2.5-pro" # ====================== LLM 接口 ====================== def llm_generate_sql(prompt): """调用大模型生成SQL""" from openai import OpenAI # (兼容OpenAI协议) client = OpenAI( api_key=MIMO_API_KEY, base_url=MIMO_BASE_URL ) response = client.chat.completions.create( model=MIMO_MODEL, messages=[{"role": "user", "content": prompt}], temperature=0 # 生成SQL必须=0 ) sql = response.choices[0].message.content.strip() # 清理格式 sql = sql.replace("```sql", "").replace("```", "").strip() return sql # ====================== 工具函数 ====================== def get_db_connection(): """获取MySQL连接""" return pymysql.connect(**MYSQL_CONFIG, cursorclass=pymysql.cursors.DictCursor) def execute_sql(sql): """执行SQL,返回结果 or 异常""" conn = get_db_connection() try: with conn.cursor() as cursor: cursor.execute(sql) if sql.strip().upper().startswith("SELECT"): return cursor.fetchall(), None else: conn.commit() return "执行成功", None except Exception as e: return None, str(e) finally: conn.close() def get_all_table_info(): """获取数据库所有表名 + 字段结构 + 注释 + 示例SQL""" # 【稳定写法】直接用 INFORMATION_SCHEMA 查表名,100% 是字典 sql = f""" SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = '{MYSQL_CONFIG["database"]}' """ tables, _ = execute_sql(sql) table_info_list = [] for row in tables: # 这里取到表名 table_name = row["TABLE_NAME"] # 表注释 comment_sql = f""" SELECT TABLE_COMMENT FROM information_schema.TABLES WHERE TABLE_SCHEMA = '{MYSQL_CONFIG["database"]}' AND TABLE_NAME = '{table_name}' """ table_comment, _ = execute_sql(comment_sql) table_comment = table_comment[0]["TABLE_COMMENT"] if table_comment else "" # 字段信息 columns, _ = execute_sql(f"DESC {table_name}") # 示例SQL example_sql = get_example_sql_by_table(table_name) # 拼接表结构描述 table_info = f""" 表名:{table_name} 表注释:{table_comment} 字段:{json.dumps(columns, ensure_ascii=False, indent=2)} 示例SQL: {example_sql} """ table_info_list.append(table_info) return "\n=====================\n".join(table_info_list) def get_example_sql_by_table(table_name): """根据表名返回业务示例SQL""" sql_map = { "fin_account_subject": "SELECT * FROM fin_account_subject LIMIT 5;", "fin_company_contact": "SELECT contact_name,contact_credit FROM fin_company_contact WHERE contact_type=1;", "fin_balance_sheet": "SELECT * FROM fin_balance_sheet WHERE report_year=2026 AND report_month=2;", "fin_profit_statement": "SELECT report_year,report_month,main_income,net_profit FROM fin_profit_statement ORDER BY report_year,report_month;", "fin_cash_flow": "SELECT * FROM fin_cash_flow WHERE report_year=2026 AND report_month=2;", "fin_expense_record": "SELECT expense_type,SUM(expense_amount) AS total FROM fin_expense_record GROUP BY expense_type;" } return sql_map.get(table_name, "SELECT * FROM " + table_name + " LIMIT 3;") # ====================== 核心流程:提问 → 生成SQL → 执行 → 重试 ====================== def text_to_sql_with_retry(question, max_retry=3): """ 流程: 1. 获取表结构 2. 拼接prompt给LLM 3. 生成SQL → 执行 4. 失败重试(最多3次) """ print(f"【用户问题】:{question}\n") # 1. 获取所有表结构+描述+示例SQL table_info = get_all_table_info() # 2. 构造Prompt prompt = f""" 你是专业SQL生成专家,请根据下面的数据库表结构、表注释、示例SQL, 严格按照用户问题生成【可直接运行】的MySQL语句,只返回SQL,不要任何解释。 表结构信息: {table_info} 用户问题:{question} 要求: 1. 只返回SQL,不要```、不要文字 2. 必须使用提供的表和字段 3. 日期、金额、分组、排序严格按业务逻辑 4. 字段名不要编造 """ # 3. 重试机制 for retry in range(max_retry): print(f"=== 第 {retry+1} 次生成SQL ===") try: # 小米 MiMo 生成SQL sql = llm_generate_sql(prompt) print(f"【生成SQL】:\n{sql}\n") # 执行SQL result, error = execute_sql(sql) if not error: print("【执行成功】") return { "question": question, "sql": sql, "result": result, "error": None } else: print(f"【执行失败】:{error}") # 把错误追加到prompt,让LLM修正 prompt += f"\n\n上一次生成的SQL执行报错:{error},请修正SQL,只返回正确SQL" except Exception as e: print(f"【重试异常】{str(e)}") return {"error": "重试次数耗尽,生成SQL失败"} # ====================== 测试 ====================== if __name__ == "__main__": # 你可以随便换问题 question = "查询2026年1月和2月的主营业务收入、净利润,按月份排序" # 执行 result = text_to_sql_with_retry(question, max_retry=3) # 打印最终结果 print("\n" + "=" * 50) print("最终结果:") print(json.dumps(result, ensure_ascii=False, indent=2))02 — 优化版本
我们通过LLM理解数据库的表结构来生成SQL语句,拿生成的SQL查询数据并返回结果。这逻辑本身是没问题的,但实际生产环境我们的数据库表很多,不可能一下全部查询出来放到prompt中,这样做既不合理,生成SQL也不够精准。所以我们可以引入上文所说的Milvus向量数据库,我们把数据库的表结构、字段注释、实例SQL向量化放入Milvus向量数据库。将用户的问题向量化,用Milvus向量检索出我们需要的表信息和相关实例SQL。
流程图
代码示例
初始化:
import os import openai from dotenv import load_dotenv load_dotenv() from pymilvus import MilvusClient, DataType from milvus_model.dense import OpenAIEmbeddingFunction from sqlalchemy import create_engine, text load_dotenv() openai.api_key = os.getenv("OPENAI_API_KEY") # 修复:使用新包 embedding_fn = OpenAIEmbeddingFunction( model_name='text-embedding-3-large', api_key=openai.api_key ) # 修复:不能用文件模式,必须连接 Milvus 服务 client = MilvusClient(uri="http://localhost:19530") DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db" engine = create_engine(DB_URL) # ====================== 创建3个集合 ====================== def create_collections(): for name, dim in [("ddl_knowledge", 3072), ("dbdesc_knowledge", 3072), ("q2sql_knowledge", 3072)]: if client.has_collection(name): client.drop_collection(name) schema = client.create_schema() schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True) schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=8192) schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim) index = client.prepare_index_params() index.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE") client.create_collection(name, schema=schema, index_params=index) # ====================== 获取所有表结构 ====================== def get_all_tables(): with engine.connect() as conn: tables = conn.execute(text("SHOW TABLES")).fetchall() return [t[0] for t in tables] def get_ddl(table): return f"CREATE TABLE {table} (...)" def get_columns(table): return [{"col": "id", "desc": "主键"}, {"col": "name", "desc": "名称"}] def get_examples(): return [ ("查询2026年1-2月净利润", "SELECT * FROM fin_profit_statement WHERE report_year=2026 AND report_month IN (1,2)"), ("查询各类型支出总额", "SELECT expense_type,SUM(expense_amount) FROM fin_expense_record GROUP BY expense_type") ] # ====================== 入库 ====================== def build_all(): create_collections() tables = get_all_tables() print("表:", tables) print("Milvus 财务库构建完成") if __name__ == "__main__": build_all()文本转SQL:
import os import logging import re import openai from dotenv import load_dotenv from pymilvus import MilvusClient from milvus_model.dense import OpenAIEmbeddingFunction from sqlalchemy import create_engine, text # ====================== 基础配置 ====================== logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s') load_dotenv() # ====================== OpenAI ====================== openai.api_key = os.getenv("OPENAI_API_KEY") MODEL_NAME = os.getenv("OPENAI_MODEL", "gpt-4o-mini") # ====================== 嵌入函数 ====================== embedding_fn = OpenAIEmbeddingFunction( model_name='text-embedding-3-large', api_key=openai.api_key ) # ====================== 连接 Milvus 服务 ====================== MILVUS_DB = "http://localhost:19530" client = MilvusClient(MILVUS_DB) # ====================== 财务数据库连接 ====================== DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db" engine = create_engine(DB_URL) # ====================== 向量检索工具 ====================== def retrieve(collection: str, query_emb: list, top_k=3, fields=None): results = client.search( collection_name=collection, data=[query_emb], limit=top_k, output_fields=fields ) return results[0] if results else [] # ====================== SQL 提取 ====================== def extract_sql(text: str) -> str: sql_blocks = re.findall(r'```sql\n(.*?)\n```', text, re.DOTALL) if sql_blocks: return sql_blocks[0].strip() select_match = re.search(r'SELECT.*?;', text, re.DOTALL) if select_match: return select_match.group(0).strip() return text.strip() # ====================== SQL 执行 ====================== def execute_sql(sql: str): try: with engine.connect() as conn: result = conn.execute(text(sql)) cols = result.keys() rows = result.fetchall() return True, cols, rows except Exception as e: return False, None, str(e) # ====================== LLM 生成 SQL ====================== def generate_sql(prompt: str, error_msg=None): if error_msg: prompt += f"\n上一次执行报错:{error_msg},请修正SQL,只返回正确SQL语句" response = openai.chat.completions.create( model=MODEL_NAME, messages=[{"role": "user", "content": prompt}] ) raw = response.choices[0].message.content.strip() sql = extract_sql(raw) logging.info(f"生成SQL: {sql}") return sql # ====================== 核心财务库 Text2SQL ====================== def text2sql_finance(question: str, max_retries=3): # 1. 问题向量化 q_emb = embedding_fn([question])[0] # 2. 三大向量检索 ddl_hits = retrieve("ddl_knowledge", q_emb, top_k=3, fields=["text"]) q2sql_hits = retrieve("q2sql_knowledge", q_emb, top_k=3, fields=["question", "sql_text"]) desc_hits = retrieve("dbdesc_knowledge", q_emb, top_k=5, fields=["table_name", "column_name", "description"]) # 3. 拼接上下文 ddl_text = "\n".join([h.get("text", "") for h in ddl_hits]) example_text = "\n".join([f"问题:{h['question']}\nSQL:{h['sql_text']}" for h in q2sql_hits]) desc_text = "\n".join([f"{h['table_name']}.{h['column_name']}:{h['description']}" for h in desc_hits]) # 4. 构造Prompt prompt = f""" 你是财务SQL专家,请根据下面的表结构、字段说明、示例,生成可直接运行的MySQL语句,只返回SQL。 ### 表结构: {ddl_text} ### 字段含义: {desc_text} ### 参考示例: {example_text} ### 用户问题: {question} 要求: 1. 只返回SQL,不要任何解释 2. 字段必须真实存在 3. 日期、分组、排序严格按财务逻辑 4. 不要编造字段 """ # 5. 重试执行 last_err = None for i in range(max_retries): logging.info(f"第 {i+1} 次生成") sql = generate_sql(prompt, last_err) ok, cols, res = execute_sql(sql) if ok: print("\n执行成功") print("字段:", cols) for row in res: print(row) return last_err = res logging.error(f"失败:{last_err}") print("超过最大重试次数") print("最后错误:", last_err) # ====================== 测试 ====================== if __name__ == "__main__": q = input("请输入财务查询问题:") text2sql_finance(q)学AI大模型的正确顺序,千万不要搞错了
🤔2026年AI风口已来!各行各业的AI渗透肉眼可见,超多公司要么转型做AI相关产品,要么高薪挖AI技术人才,机遇直接摆在眼前!
有往AI方向发展,或者本身有后端编程基础的朋友,直接冲AI大模型应用开发转岗超合适!
就算暂时不打算转岗,了解大模型、RAG、Prompt、Agent这些热门概念,能上手做简单项目,也绝对是求职加分王🔋
📝给大家整理了超全最新的AI大模型应用开发学习清单和资料,手把手帮你快速入门!👇👇
学习路线:
✅大模型基础认知—大模型核心原理、发展历程、主流模型(GPT、文心一言等)特点解析
✅核心技术模块—RAG检索增强生成、Prompt工程实战、Agent智能体开发逻辑
✅开发基础能力—Python进阶、API接口调用、大模型开发框架(LangChain等)实操
✅应用场景开发—智能问答系统、企业知识库、AIGC内容生成工具、行业定制化大模型应用
✅项目落地流程—需求拆解、技术选型、模型调优、测试上线、运维迭代
✅面试求职冲刺—岗位JD解析、简历AI项目包装、高频面试题汇总、模拟面经
以上6大模块,看似清晰好上手,实则每个部分都有扎实的核心内容需要吃透!
我把大模型的学习全流程已经整理📚好了!抓住AI时代风口,轻松解锁职业新可能,希望大家都能把握机遇,实现薪资/职业跃迁~