用Python玩转SAM模型:零基础打造智能抠图工具
设计师朋友小李最近接了个电商产品图精修的单子,客户发来的原始照片背景杂乱无章。传统钢笔工具抠图需要反复调整锚点,一张图就要耗费半小时。直到他发现了Meta发布的Segment Anything模型——这个能通过简单点击自动识别物体的AI工具,让他的工作效率提升了十倍。本文将带你从零开始,用Python代码将SAM模型变成你的私人抠图助手。
1. 环境配置与模型准备
1.1 基础环境搭建
首先需要准备Python 3.8+环境,推荐使用Anaconda创建独立环境避免依赖冲突:
conda create -n sam_env python=3.9 conda activate sam_env安装核心依赖库时要注意版本兼容性:
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip install opencv-python matplotlib numpy segment-anything提示:CUDA版本需要与显卡驱动匹配,NVIDIA用户可通过
nvidia-smi查看支持的CUDA版本
1.2 模型文件获取
SAM提供三种预训练模型,根据硬件条件选择:
| 模型类型 | 参数量 | 显存占用 | 适用场景 |
|---|---|---|---|
| vit_h | 636M | 8GB+ | 高精度专业级 |
| vit_l | 308M | 4-6GB | 平衡型 |
| vit_b | 91M | 2-3GB | 快速测试 |
下载模型权重文件后(如sam_vit_b_01ec64.pth),建议存放在项目根目录的models文件夹中。
2. 基础抠图功能实现
2.1 单点精准抠图
以下代码实现点击图片某处自动抠出目标物体:
import cv2 import numpy as np from segment_anything import sam_model_registry, SamPredictor def init_sam(model_path="models/sam_vit_b_01ec64.pth"): sam = sam_model_registry["vit_b"](checkpoint=model_path) sam.to(device="cuda") return SamPredictor(sam) def single_point_cutout(image_path, point_x, point_y): image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) predictor = init_sam() predictor.set_image(image) input_point = np.array([[point_x, point_y]]) input_label = np.array([1]) # 1表示前景点 masks, _, _ = predictor.predict( point_coords=input_point, point_labels=input_label, multimask_output=True ) return masks[0] # 返回置信度最高的掩膜实际应用时,可以通过OpenCV的鼠标回调实现交互式点选:
def on_mouse_click(event, x, y, flags, param): if event == cv2.EVENT_LBUTTONDOWN: mask = single_point_cutout("product.jpg", x, y) show_result(mask)2.2 框选批量抠图
对于轮廓清晰的物体,矩形框选效率更高:
def box_cutout(image_path, x1, y1, x2, y2): image = cv2.imread(image_path) predictor = init_sam() predictor.set_image(image) input_box = np.array([x1, y1, x2, y2]) masks, _, _ = predictor.predict( box=input_box[None, :], multimask_output=False ) return masks[0]电商产品图中常见的多物体抠图场景:
# 同时抠取图片中的手机和耳机 boxes = [ [120, 80, 300, 400], # 手机坐标 [350, 200, 450, 350] # 耳机坐标 ] combined_mask = np.zeros_like(image) for box in boxes: mask = box_cutout("electronics.jpg", *box) combined_mask = np.logical_or(combined_mask, mask)3. 高级抠图技巧
3.1 复杂场景精修
当目标物体与背景颜色相近时,可以结合正负点提示:
def refine_cutout(image_path, box, positive_points, negative_points): image = cv2.imread(image_path) predictor = init_sam() predictor.set_image(image) # 合并正负点 all_points = np.array(positive_points + negative_points) all_labels = np.array([1]*len(positive_points) + [0]*len(negative_points)) masks, _, _ = predictor.predict( point_coords=all_points, point_labels=all_labels, box=np.array(box), multimask_output=False ) return masks[0]示例:抠取玻璃杯中的液体
glass_box = [200, 150, 400, 500] liquid_points = [[300, 300], [320, 280]] # 液体区域点 frame_points = [[280, 180]] # 杯框干扰点 mask = refine_cutout("glass.jpg", glass_box, liquid_points, frame_points)3.2 批量自动化处理
结合对象检测模型实现全自动流水线处理:
from detectron2 import model_zoo from detectron2.engine import DefaultPredictor def auto_cutout_pipeline(image_path): # 第一步:检测物体 detector = DefaultPredictor(model_zoo.get("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")) detections = detector(cv2.imread(image_path)) # 第二步:SAM抠图 sam_predictor = init_sam() image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) sam_predictor.set_image(image) results = [] for box in detections['instances'].pred_boxes: masks, _, _ = sam_predictor.predict( box=box.cpu().numpy()[None, :], multimask_output=False ) results.append(masks[0]) return results4. 实用化功能扩展
4.1 透明背景生成
将抠图结果保存为PNG透明图片:
def save_transparent(image_path, mask, output_path): image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) if image.shape[2] == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA) image[:, :, 3] = mask * 255 # 设置alpha通道 cv2.imwrite(output_path, image)4.2 背景替换合成
实现电商常见的场景切换效果:
def change_background(orig_path, mask, bg_path, output_path): foreground = cv2.imread(orig_path) background = cv2.imread(bg_path) # 调整背景尺寸匹配前景 background = cv2.resize(background, (foreground.shape[1], foreground.shape[0])) # 合成图像 composite = np.where(mask[..., None], foreground, background) cv2.imwrite(output_path, composite)4.3 批量处理工具开发
用PyQt构建可视化操作界面:
from PyQt5.QtWidgets import (QApplication, QMainWindow, QFileDialog, QGraphicsScene) class SAMEditor(QMainWindow): def __init__(self): super().__init__() self.init_ui() self.sam_predictor = init_sam() def open_image(self): path, _ = QFileDialog.getOpenFileName() self.image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) self.display_image(self.image) def mousePressEvent(self, event): if hasattr(self, 'image'): x, y = event.pos().x(), event.pos().y() mask = self.single_point_cutout(x, y) self.show_mask(mask)实际测试中发现,对于毛绒玩具等边缘模糊的物体,适当增加负样本点(背景点)能显著提升分割精度。而在处理金属反光物体时,矩形框选比点选效果更稳定。