1234523767123
立即下载
资源介绍:
1234523767123
# train.py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import OneHotEncoder
from model import FCNet
import os
import pickle
import sklearn
from packaging import version
# Seed the torch and NumPy global random number generators so that runs are reproducible
torch.manual_seed(42)
np.random.seed(42)
# Custom dataset
class AminoAcidDataset(Dataset):
    """Dataset of one-, two- and three-residue amino-acid sequences.

    Every sample is a ``(3, N + 1)`` float32 matrix with one row per
    sequence position: the one-hot code of the residue (``N`` is the number
    of encoder categories) followed by that residue's value from the
    single-residue table (``0.0`` when the symbol has no entry). Sequences
    shorter than three residues are right-padded with ``'-'``.

    Targets are float32 pairs. Single residues use ``[Value, 0.0]``; pairs
    and triples use ``[Value1, Value2]``.

    Attributes:
        single_dict: Mapping from residue to its scalar value.
        amino_acids: Sorted residues plus the padding symbol ``'-'``.
        encoder: ``OneHotEncoder`` fitted on ``amino_acids``.
        X: Feature array of shape ``(num_samples, 3, N + 1)``.
        y: Target array of shape ``(num_samples, 2)``.
    """

    def __init__(self, single_csv, double_csv, triple_csv):
        """Load the three CSV tables and build the feature/target arrays.

        Args:
            single_csv: Header-less CSV with columns (residue, value).
            double_csv: Header-less CSV with columns (pair, value1, value2).
            triple_csv: Header-less CSV with columns (triple, value1, value2).

        Rows that fail validation or encoding are reported with ``print``
        and skipped instead of aborting the load.
        """
        # The files carry no header row, so column names are assigned here.
        single = pd.read_csv(single_csv, header=None, names=['AminoAcid', 'Value'])
        double = pd.read_csv(double_csv, header=None, names=['AminoAcidPair', 'Value1', 'Value2'])
        triple = pd.read_csv(triple_csv, header=None, names=['AminoAcidTriple', 'Value1', 'Value2'])

        # Residue -> scalar value lookup.
        self.single_dict = single.set_index('AminoAcid')['Value'].to_dict()
        # Encoder alphabet: every known residue plus the padding symbol '-'.
        self.amino_acids = sorted(self.single_dict.keys()) + ['-']

        # scikit-learn 1.2 renamed the ``sparse`` argument to ``sparse_output``.
        if version.parse(sklearn.__version__) >= version.parse("1.2"):
            self.encoder = OneHotEncoder(sparse_output=False)
        else:
            self.encoder = OneHotEncoder(sparse=False)
        self.encoder.fit(np.array(self.amino_acids).reshape(-1, 1))

        # Memo of symbol -> one-hot row, so each symbol hits the encoder once
        # instead of once per row and position.
        self._one_hot_cache = {}

        self.X = []
        self.y = []
        # Order matters: samples are stored singles first, then pairs, then
        # triples.
        self._load_singles(single)
        self._load_pairs(double)
        self._load_triples(triple)

        self.X = np.array(self.X, dtype=np.float32)  # (num_samples, 3, N + 1)
        self.y = np.array(self.y, dtype=np.float32)  # (num_samples, 2)

    def _one_hot(self, symbol):
        """Return the one-hot row for *symbol*, encoding it at most once."""
        row = self._one_hot_cache.get(symbol)
        if row is None:
            # Raises ValueError for a symbol the encoder was not fitted on.
            row = self.encoder.transform([[symbol]])[0]
            self._one_hot_cache[symbol] = row
        return row

    def _try_encode(self, triplet):
        """Build the (3, N + 1) feature matrix for the first three symbols.

        Returns ``None`` (after printing an error) when a symbol is unknown
        to the encoder.
        """
        try:
            one_hots = [self._one_hot(symbol) for symbol in triplet[:3]]
        except ValueError as e:
            print(f"错误:无法对氨基酸 '{triplet}' 进行编码,已跳过。错误信息:{e}")
            return None
        # zip() stops after the three encoded positions.
        return np.stack([
            np.concatenate([one_hot, [self.single_dict.get(symbol, 0.0)]])
            for symbol, one_hot in zip(triplet, one_hots)
        ])

    def _load_singles(self, single):
        """Append one sample per single-residue row, padded to e.g. 'A--'."""
        for index, row in single.iterrows():
            aa = row['AminoAcid']
            # Mirror the pair/triple loaders: skip non-string cells instead
            # of failing on the string concatenation below.
            if not isinstance(aa, str):
                print(f"警告:第 {index} 行的氨基酸不是字符串,值为 '{aa}',已跳过。")
                continue
            if aa not in self.single_dict:
                print(f"警告:氨基酸 '{aa}' 不在单个氨基酸字典中,已跳过。")
                continue
            feature = self._try_encode(aa + '--')
            if feature is None:
                continue
            self.X.append(feature)
            # Single residues have no second target; it is fixed to 0.0.
            self.y.append([float(row['Value']), 0.0])

    def _load_pairs(self, double):
        """Append one sample per valid two-residue row, padded to e.g. 'AC-'."""
        for index, row in double.iterrows():
            aa_pair = row['AminoAcidPair']
            if not isinstance(aa_pair, str):
                print(f"警告:第 {index} 行的氨基酸对不是字符串,值为 '{aa_pair}',已跳过。")
                continue
            if len(aa_pair) != 2:
                print(f"警告:第 {index} 行的氨基酸对 '{aa_pair}' 不包含两个氨基酸,已跳过。")
                continue
            aa1, aa2 = aa_pair
            if aa1 not in self.single_dict or aa2 not in self.single_dict:
                print(f"警告:氨基酸 '{aa1}' 或 '{aa2}' 不在单个氨基酸字典中,已跳过。")
                continue
            feature = self._try_encode(aa_pair + '-')
            if feature is None:
                continue
            self.X.append(feature)
            self.y.append([float(row['Value1']), float(row['Value2'])])

    def _load_triples(self, triple):
        """Append one sample per valid three-residue row."""
        for index, row in triple.iterrows():
            aa_triple = row['AminoAcidTriple']
            if not isinstance(aa_triple, str):
                print(f"警告:第 {index} 行的氨基酸三联不是字符串,值为 '{aa_triple}',已跳过。")
                continue
            if len(aa_triple) != 3:
                print(f"警告:第 {index} 行的氨基酸三联 '{aa_triple}' 不包含三个氨基酸,已跳过。")
                continue
            aa1, aa2, aa3 = aa_triple
            if aa1 not in self.single_dict or aa2 not in self.single_dict or aa3 not in self.single_dict:
                print(f"警告:氨基酸 '{aa1}', '{aa2}' 或 '{aa3}' 不在单个氨基酸字典中,已跳过。")
                continue
            feature = self._try_encode(aa_triple)
            if feature is None:
                continue
            self.X.append(feature)
            self.y.append([float(row['Value1']), float(row['Value2'])])

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(feature, target)`` pair stored at *idx*."""
        return self.X[idx], self.y[idx]
def train():
# 数据文件路径
single_csv = '../data/single.csv'
double_csv = '../data/double.csv'
triple_csv = '../data/triple.csv'
# 检查文件是否存在
if not os.path.exists(single_csv):
print(f"错误:单个氨基酸数据文件 '{single_csv}' 不存在。")
return
if not os.path.exists(double_csv):
print(f"错误:组合氨基酸数据文件 '{double_csv}' 不存在。")
return
if not os.path.exists(triple_csv):
print(f"错误:三联氨基酸数据文件 '{triple_csv}' 不存在。")
return
# 创建数据集
dataset = AminoAcidDataset(single_csv, double_csv, triple_csv)
print(f"总样本数量:{len(dataset)}")
if len(dataset) == 0:
print("错误:数据集中没有有效的样本。请检查数据文件的内容和格式。")
return
# 数据集划分:80% 训练,20% 验证
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = rando
资源文件列表:
adsorption_proj_1+2--2_version4.0 (perfect version).zip 大约有104个文件