相关项目下载链接
训练框架
在开始实现相应模块功能之前,首先熟悉训练框架·train.py。
1. 导入与模型字典构建
importinspectimportmathfromdatetimeimportdatetimefrompathlibimportPathimporttorchimportae,autoregressive,bsq# 自定义模型模块(AE/BSQ/自回归)# 收集ae/bsq模块中所有继承nn.Module的块级模型类patch_models={n:mforMin[ae,bsq]forn,mininspect.getmembers(M)ifinspect.isclass(m)andissubclass(m,torch.nn.Module)}# 收集autoregressive模块中所有继承nn.Module的自回归模型类ar_models={n:mforMin[autoregressive]forn,mininspect.getmembers(M)ifinspect.isclass(m)andissubclass(m,torch.nn.Module)}2. 核心训练函数 train()
共包含三个部分:块级模型训练器 PatchTrainer、自回归模型训练器 AutoregressiveTrainer、模型保存回调 CheckPointer。
其中,
- 块级模型训练器 PatchTrainer专用于 AE/BSQ 模型。值得注意的是数据预处理过程,图像归一化的方式是(/255.0 - 0.5),将像素值映射到[-0.5, 0.5]而不是[0,1];损失函数采用MSE(均方误差),适配图像重构任务;优化器为AdamW,学习率1e-3;基于
ImageDataset加载原始图像数据集。 - 自回归模型训练器 AutoregressiveTrainer专用于 AR 模型。使用交叉熵损失,适配令牌序列的分类预测任务;基于
TokenDataset加载令牌化后的图像序列;优化器为AdamW,学习率1e-3。 - 模型保存回调 CheckPointer。模型保存的触发时机是在每个训练 epoch 结束后;有两种保存方式:一种是带时间戳的模型,保存方式为
checkpoints/{时间戳}_{模型名}.pth;另一种是最新的模型,保存路径为当前目录下的{模型名}.pth。
此外,还实现了模型加载 / 创建的逻辑。
deftrain(model_name_or_path:str,epochs:int=5,batch_size:int=64):importlightningasLfromlightning.pytorch.loggersimportTensorBoardLoggerfromdataimportImageDataset,TokenDatasetclassPatchTrainer(L.LightningModule):def__init__(self,model):super().__init__()self.model=modeldeftraining_step(self,x,batch_idx):x=x.float()/255.0-0.5x_hat,additional_losses=self.model(x)loss=torch.nn.functional.mse_loss(x_hat,x)self.log("train/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"train/{k}",v)returnloss+sum(additional_losses.values())defvalidation_step(self,x,batch_idx):x=x.float()/255.0-0.5withtorch.no_grad():x_hat,additional_losses=self.model(x)loss=torch.nn.functional.mse_loss(x_hat,x)self.log("validation/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"validation/{k}",v)ifbatch_idx==0:self.logger.experiment.add_images("input",(x[:64]+0.5).clamp(min=0,max=1).permute(0,3,1,2),self.global_step)self.logger.experiment.add_images("prediction",(x_hat[:64]+0.5).clamp(min=0,max=1).permute(0,3,1,2),self.global_step)returnlossdefconfigure_optimizers(self):returntorch.optim.AdamW(self.parameters(),lr=1e-3)deftrain_dataloader(self):dataset=ImageDataset("train")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)defval_dataloader(self):dataset=ImageDataset("valid")returntorch.utils.data.DataLoader(dataset,batch_size=4096,num_workers=4,shuffle=True)classAutoregressiveTrainer(L.LightningModule):def__init__(self,model):super().__init__()self.model=modeldeftraining_step(self,x,batch_idx):x_hat,additional_losses=self.model(x)loss=(torch.nn.functional.cross_entropy(x_hat.view(-1,x_hat.shape[-1]),x.view(-1),reduction="sum")/math.log(2)/x.shape[0])self.log("train/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"train/{k}",v)returnloss+sum(additional_losses.values())defvalidation_step(self,x,batch_idx):withtorch.no_grad():x_hat,additional_losses=self.model(x)loss=(torch.nn.functional.cross_entropy(x_hat.view(-1,x_hat.shape[-1]),x.view(-1),reduction="sum")/math.log(2)/x.shape[0])self.log("validation/loss",loss,prog_bar=True)fork,vinadditional_losses.items():self.log(f"validation/{k}",v)returnlossdefconfigure_optimizers(self):returntorch.optim.AdamW(self.parameters(),lr=1e-3)deftrain_dataloader(self):dataset=TokenDataset("train")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)defval_dataloader(self):dataset=TokenDataset("valid")returntorch.utils.data.DataLoader(dataset,batch_size=batch_size,num_workers=4,shuffle=True)classCheckPointer(L.Callback):defon_train_epoch_end(self,trainer,pl_module):fn=Path(f"checkpoints/{timestamp}_{model_name}.pth")fn.parent.mkdir(exist_ok=True,parents=True)torch.save(model,fn)torch.save(model,Path(__file__).parent/f"{model_name}.pth")# Load or create the modelifPath(model_name_or_path).exists():model=torch.load(model_name_or_path,weights_only=False)model_name=model.__class__.__name__else:model_name=model_name_or_pathifmodel_nameinpatch_models:model=patch_models[model_name]()elifmodel_nameinar_models:model=ar_models[model_name]()else:raiseValueError(f"Unknown model:{model_name}")# Create the lightning modelifisinstance(model,(autoregressive.Autoregressive)):l_model=AutoregressiveTrainer(model)else:l_model=PatchTrainer(model)timestamp=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")logger=TensorBoardLogger("logs",name=f"{timestamp}_{model_name}")trainer=L.Trainer(max_epochs=epochs,logger=logger,callbacks=[CheckPointer()])trainer.fit(model=l_model,)3. 命令行启动
本项目借助fire库实现命令行参数解析,无需手动解析--epochs/--batch_size等参数,直接通过python train.py {模型名} --epochs 10启动训练。
if__name__=="__main__":fromfireimportFire Fire(train)train.py核心使用方法如下:
# 训练块级自编码器python train.py PatchAutoEncoder--epochs5--batch_size64# 训练自回归模型python train.py AutoregressiveModel--epochs10--batch_size32# 加载已有模型续训python train.py checkpoints/2025-10-20_PatchAutoEncoder.pth--epochs10加载数据
接下来熟悉这个项目是如何进行数据加载的,data.py模块定义两类 PyTorch 兼容的数据集类。
其中:
ImageDataset:加载原始 JPG 图像,提供缓存机制提升读取效率;TokenDataset:加载令牌化后的图像张量(由tokenize.py生成),供自回归模型训练使用。。
1. 导入依赖库
frompathlibimportPathimporttorchfromPILimportImage# 自动定位数据集根目录:当前文件的父父目录下的data文件夹DATASET_PATH=Path(__file__).parent.parent/"data"2. ImageDataset(原始图像数据集)
classImageDataset:def__init__(self,split:str,cache_images:bool=True):# 收集split(train/valid)目录下所有.jpg文件路径self.image_paths=list((DATASET_PATH/split).rglob("*.jpg"))# 初始化图像缓存列表,避免重复读取磁盘self._image_cache=[None]*len(self.image_paths)self._cache_images=cache_images# 是否开启缓存def__len__(self)->int:returnlen(self.image_paths)# 数据集总长度def__getitem__(self,idx:int)->torch.Tensor:# 优先读取缓存,无缓存则加载图像ifself._image_cache[idx]isnotNone:returnself._image_cache[idx]# 图像加载:PIL打开→转numpy数组→转torch.uint8张量(保持原始像素值)img=torch.tensor(np.array(Image.open(self.image_paths[idx])),dtype=torch.uint8)# 开启缓存则存入,后续复用ifself._cache_images:self._image_cache[idx]=imgreturnimg3. TokenDataset(令牌化数据集)
classTokenDataset(torch.utils.data.TensorDataset):def__init__(self,split:str):# 加载令牌化后的张量文件(由tokenize.py生成)tensor_path=DATASET_PATH/f"tokenized_{split}.pth"ifnottensor_path.exists():# 文件不存在时给出明确提示,符合作业流程指引raiseFileNotFoundError(f"Tokenized dataset not found at{tensor_path}...")self.data=torch.load(tensor_path,weights_only=False)def__getitem__(self,idx:int)->torch.Tensor:# 返回长整型张量(适配自回归模型的离散令牌输入)returntorch.tensor(self.data[idx],dtype=torch.long)def__len__(self)->int:returnlen(self.data)这两个数据集加载对象的使用方法如下所示:
# 加载训练集原始图像(用于AE/BSQ训练)fromdataimportImageDataset,TokenDataset train_img_ds=ImageDataset("train",cache_images=True)img_tensor=train_img_ds[0]# 取第0张图像,shape: (H, W, 3)# 加载训练集令牌数据(用于自回归模型训练)train_token_ds=TokenDataset("train")token_tensor=train_token_ds[0]# 取第0个令牌序列,shape: (序列长度,)# 配合DataLoader使用fromtorch.utils.dataimportDataLoader train_loader=DataLoader(train_token_ds,batch_size=64,shuffle=True)