PyTorch Snippet

import random

import numpy as np
import torch
from torch import nn

# --- Detect device ---------------------------------------------------------
# Get cpu, gpu or mps device for training.
# Preference order: CUDA GPU, then Apple-silicon MPS, then plain CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


# --- Reproducibility -------------------------------------------------------
def reproducibility(seed: int = 8848) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator so that repeated runs draw
            the same random numbers.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # ``torch.cuda.manual_seed`` only seeds the *current* GPU; seed every
    # visible GPU so multi-GPU runs are covered too. The call is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


# --- Initialize weights ----------------------------------------------------
def init_params(model: nn.Module) -> None:
    """Re-initialize every parameter of *model* in place.

    Parameters with more than one dimension (weight matrices, convolution
    kernels) get Xavier-normal initialization; one-dimensional or scalar
    parameters (e.g. biases) are drawn from a standard normal distribution.

    Args:
        model: Module whose parameters are overwritten.
    """
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_normal_(p)
        else:
            # Xavier needs fan-in/fan-out, which is undefined below 2-D.
            nn.init.normal_(p)

May 28, 2024 · 1 min · 58 words · Me