-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_eval.py
More file actions
159 lines (125 loc) · 5.79 KB
/
train_eval.py
File metadata and controls
159 lines (125 loc) · 5.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import copy
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from tqdm import tqdm
class SetTarget:
    """Transform that keeps a single target column of a sample's labels.

    Each call replaces ``data.y`` with ``data.y[:, target]``, so downstream
    code only sees the label for the chosen regression target (the original
    author notes that QM9 provides 19 targets per sample).

    Args:
        target: Index into the second dimension of ``data.y`` selecting the
            label column to keep.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target

    def __call__(self, data):
        """Overwrite ``data.y`` with the selected column and return ``data``.

        The sample is modified in place; the same object is returned.
        """
        data.y = data.y[:, self.target]
        return data

    def __repr__(self):
        # Gives a readable entry when the transform is printed.
        return f"{type(self).__name__}(target={self.target!r})"
def prepare_data_qm9(reduced_set=True, path='./qm9', target=7, batch_size=32):
    """Build train/validation/test data loaders over the QM9 dataset.

    The dataset is loaded from ``path`` with a single transform,
    ``SetTarget(target)``, which keeps only the selected label column.
    Labels are used as stored: no normalization is applied here.

    Splits are contiguous index ranges of the dataset in its stored order
    (the dataset is not shuffled before splitting):

    * ``reduced_set=True``: samples ``[0, 1000)`` for training,
      ``[1000, 2000)`` for validation and ``[2000, 3000)`` for testing
      (the "hardware friendly" split).
    * ``reduced_set=False``: samples ``[0, 110_000)`` for training,
      ``[110_000, 120_000)`` for validation and the remainder for testing
      (the split the original author attributes to the paper's spec).

    Args:
        reduced_set: Use the small 1000/1000/1000 split instead of the full one.
        path: Root directory handed to ``QM9``.
        target: Index of the label column to keep.
        batch_size: Batch size used by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.
    """
    # Transform applied during data loading: select the target/label column.
    transform = T.Compose([SetTarget(target=target)])
    dataset = QM9(path, transform=transform)

    if reduced_set:
        train_dataset = dataset[:1000]
        val_dataset = dataset[1000:2000]
        test_dataset = dataset[2000:3000]
    else:
        train_dataset = dataset[:110_000]
        val_dataset = dataset[110_000:120_000]
        test_dataset = dataset[120_000:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader
def extract_max_z():
    """Return the maximum of the ``z`` attribute over the QM9 dataset.

    The dataset is loaded from ``./qm9`` without any transform; the result is
    whatever ``.max()`` yields on the dataset's ``z`` attribute.
    """
    qm9 = QM9('./qm9')
    z_values = qm9.z
    return z_values.max()
def train(model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch the model output is compared against ``data.y`` with the
    mean-squared-error loss, gradients are back-propagated and the optimizer
    takes one step.

    Args:
        model: Callable module; ``model(data)`` must return predictions with
            the same shape as ``data.y``.
        train_loader: Iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``; must also expose ``dataset`` supporting ``len``.
        optimizer: Optimizer over the model's parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The batch losses weighted by ``data.num_graphs``, summed and divided
        by ``len(train_loader.dataset)``.

    Raises:
        ValueError: If predictions and labels have different shapes.
    """
    model.train()
    loss_all = 0.0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        y_pred = model(data)
        # mse_loss broadcasts mismatched shapes (e.g. (B, 1) vs (B,)) instead
        # of failing, which silently yields a wrong loss; refuse that up front.
        if y_pred.shape != data.y.shape:
            raise ValueError(
                f"Prediction shape {tuple(y_pred.shape)} does not match "
                f"label shape {tuple(data.y.shape)}"
            )
        loss = F.mse_loss(y_pred, data.y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final value is a per-sample average.
        loss_all += loss.item() * data.num_graphs
    return loss_all / len(train_loader.dataset)
def eval(model, loader, device):
    """Return the mean absolute error of ``model`` over ``loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. Errors are computed directly on ``data.y`` as provided by the
    loader; no rescaling (e.g. by a label standard deviation) is applied.

    Note: this function intentionally keeps its historical name, which
    shadows the ``eval`` builtin inside this module.

    Args:
        model: Callable module; ``model(data)`` must return predictions with
            the same shape as ``data.y``.
        loader: Iterable of batches exposing ``to(device)`` and ``y``; must
            also expose ``dataset`` supporting ``len``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Sum of absolute errors over all batches divided by
        ``len(loader.dataset)``.

    Raises:
        AssertionError: If predictions and labels have different shapes.
    """
    model.eval()
    error = 0.0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            y_pred = model(data)
            # Raised explicitly (rather than via ``assert``) so the guard is
            # not stripped under ``python -O``; mismatched shapes would
            # otherwise broadcast and produce a meaningless error value.
            if data.y.shape != y_pred.shape:
                raise AssertionError('Shapes do not match, if they differ, the loss calculation often does weird things')
            error += (y_pred - data.y).abs().sum().item()
    return error / len(loader.dataset)
def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100, patience=10):
    """Train ``model`` with early stopping and report validation/test MAE.

    Each epoch trains on ``train_loader`` and evaluates on ``val_loader``.
    Whenever the validation MAE is at least as good as the best one seen so
    far, the test MAE is measured and a snapshot of the model weights is
    taken. Training stops early once ``patience`` consecutive epochs fail to
    match or improve the best validation MAE. Before returning, the model is
    restored to the best snapshot.

    Args:
        model: Module to train; it is moved to CUDA when available, else CPU.
        model_name: Label used in log output and in ``perf_per_epoch`` entries.
        train_loader: Loader used for training.
        val_loader: Loader used for model selection and LR scheduling.
        test_loader: Loader evaluated only when the validation MAE improves.
        n_epochs: Maximum number of epochs; must be at least 1.
        patience: Number of consecutive non-improving epochs tolerated before
            stopping early.

    Returns:
        Tuple ``(best_val_error, test_error, train_time, perf_per_epoch)``
        where ``test_error`` is the test MAE measured at the best validation
        epoch, ``train_time`` is the wall-clock duration in minutes and
        ``perf_per_epoch`` is a list of
        ``(test_error, val_error, epoch, model_name)`` tuples. The epoch on
        which early stopping triggers is not appended to that list.

    Raises:
        ValueError: If ``n_epochs`` is smaller than 1.
    """
    if n_epochs < 1:
        # Without at least one epoch there is no best snapshot to restore and
        # no metrics to report.
        raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")

    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\nModel architecture:")
    print(model)
    total_param = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_param}')
    model = model.to(device)

    # Adam optimizer with LR 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # LR scheduler which decays LR when validation metric doesn't improve.
    # NOTE(review): the scheduler patience (15) is larger than the default
    # early-stopping patience (10), so with default arguments an LR decay may
    # rarely trigger before training stops -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.8, patience=15, min_lr=1e-7)

    print("\nStart training:")
    best_val_error = None
    test_error = None
    best_state = None
    patience_counter = 0
    perf_per_epoch = []  # Track Test/Val MAE vs. epoch (for plotting)
    start = time.time()

    with tqdm(total=n_epochs, desc='Training model...') as bar:
        with tqdm(bar_format='{desc}') as line2:
            for epoch in range(1, n_epochs + 1):
                # Read the LR in effect for this epoch (for the progress line);
                # the scheduler itself is stepped after validation below.
                lr = scheduler.optimizer.param_groups[0]['lr']

                # Train model for one epoch, return avg. training loss
                loss = train(model, train_loader, optimizer, device)

                # Evaluate model on validation set
                val_error = eval(model, val_loader, device)

                if best_val_error is None or val_error <= best_val_error:
                    # Evaluate model on test set if validation metric improves
                    test_error = eval(model, test_loader, device)
                    # state_dict() returns references to the live tensors, so
                    # a deep copy is required to actually freeze the weights.
                    best_state = copy.deepcopy(model.state_dict())
                    best_val_error = val_error
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {epoch}, best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.")
                        break

                scheduler.step(val_error)
                perf_per_epoch.append((test_error, val_error, epoch, model_name))

                bar.update()
                line2.set_description(f'Epoch {epoch}/{n_epochs}: '
                                      f'LR: {lr:.1e}, Patience: {patience_counter}/{patience}, '
                                      f'Loss: {loss:.3f}, Val MAE: {val_error:.3f}, '
                                      f'Best Val MAE: {best_val_error:.3f}, Test MAE: {test_error:.3f}')

    # Restore the weights that achieved the best validation MAE.
    model.load_state_dict(best_state)

    train_time = (time.time() - start) / 60
    print(f"\nDone! Training took {train_time:.2f} mins. Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.")

    return best_val_error, test_error, train_time, perf_per_epoch