-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_tsptw.py
More file actions
56 lines (45 loc) · 2.09 KB
/
Copy pathtrain_tsptw.py
File metadata and controls
56 lines (45 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# train_tsptw.py
import torch
import argparse
from models.tsptw_gnn_model import TSPTWGraphEncoder
from models.tsptw_trainer import TSPTWTrainer, TSPTWDataset
from data.tsptw_generator import TSPTWDataGenerator
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings to parse. When left as
            ``None`` argparse falls back to ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list lets the parser be driven programmatically
            (notebooks, tests) without touching ``sys.argv``.

    Returns:
        Namespace with the attributes ``n_nodes``, ``grid_size``,
        ``n_train``, ``n_val``, ``epochs``, ``batch_size``, ``lr``,
        ``embedding_dim``, ``n_layers`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Train the TSPTW graph encoder.",
        # Appends "(default: ...)" to every help line so --help is
        # self-describing.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # One row per option: (flag, type, default, help text). The help text
    # describes where main() forwards each value.
    options = (
        ("--n_nodes", int, 20, "n_nodes passed to the data generator"),
        ("--grid_size", int, 100, "grid_size passed to the data generator"),
        ("--n_train", int, 2000, "instance count requested for the train set"),
        ("--n_val", int, 400, "instance count requested for the val set"),
        ("--epochs", int, 30, "epochs passed to the trainer"),
        ("--batch_size", int, 32, "batch_size passed to the trainer"),
        ("--lr", float, 1e-3, "lr passed to the trainer"),
        ("--embedding_dim", int, 128, "embedding_dim of the graph encoder"),
        ("--n_layers", int, 5, "n_layers of the graph encoder"),
        ("--seed", int, 1234, "seed for torch and the data generator"),
    )
    for flag, value_type, default, help_text in options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    return parser.parse_args(argv)
def main():
    """Run the full training pipeline for the TSPTW graph encoder.

    Steps, in order: parse the CLI options, seed torch, pick the device,
    generate the train and val sets from a single ``TSPTWDataGenerator``,
    wrap them in ``TSPTWDataset`` objects, build a ``TSPTWGraphEncoder``,
    and hand everything to ``TSPTWTrainer.train``. Progress and the value
    returned by the trainer are printed to stdout.
    """
    cfg = parse_args()
    torch.manual_seed(cfg.seed)

    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print("Device:", device)

    generator = TSPTWDataGenerator(
        n_nodes=cfg.n_nodes,
        grid_size=cfg.grid_size,
        seed=cfg.seed,
    )

    # Both splits come from the same generator object: train first, then val.
    print("Generating train set...")
    tr_coords, tr_windows, tr_solutions = generator.generate_with_solutions(
        cfg.n_train
    )
    print("Generating val set...")
    va_coords, va_windows, va_solutions = generator.generate_with_solutions(
        cfg.n_val,
        improve_with_2opt=False,
    )

    train_set = TSPTWDataset(tr_coords, tr_windows, tr_solutions)
    val_set = TSPTWDataset(va_coords, va_windows, va_solutions)

    encoder_options = {
        'embedding_dim': cfg.embedding_dim,
        'n_layers': cfg.n_layers,
        'n_heads': 4,
        'dropout': 0.1,
    }
    encoder = TSPTWGraphEncoder(**encoder_options)

    # Count every element of every parameter tensor for the summary line.
    n_params = 0
    for tensor in encoder.parameters():
        n_params += tensor.numel()
    print(f"Model params: {n_params:,}")

    trainer = TSPTWTrainer(encoder, device=device, lr=cfg.lr)
    fit_options = {
        'epochs': cfg.epochs,
        'batch_size': cfg.batch_size,
        'save_dir': 'models/pretrained',
        'early_stop_patience': 6,
    }
    best_val = trainer.train(train_set, val_set, **fit_options)
    print("Training finished. Best val loss:", best_val)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()