18 lines
631 B
Python
18 lines
631 B
Python
import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
|
|
class GaussianFourierProjection(nn.Module):
    """Encode scalar inputs (e.g. time steps) as random Fourier features.

    A vector of ``embed_dim // 2`` frequencies is drawn once from a normal
    distribution, multiplied by ``scale``, and then kept frozen. The forward
    pass multiplies the input by those frequencies (times ``2 * pi``) and
    returns the sines followed by the cosines along the last axis.
    """

    def __init__(self, embed_dim, scale=30.):
        """Draw the frozen frequency vector.

        Args:
            embed_dim: Requested feature width. The output of ``forward`` has
                ``2 * (embed_dim // 2)`` entries on its last axis.
            scale: Multiplier applied to the normally distributed samples.
        """
        super().__init__()
        half_dim = embed_dim // 2
        frequencies = torch.randn(half_dim) * scale
        # Stored as a parameter with gradients disabled: sampled once here
        # and never updated by back-propagation.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` joined on the last axis.

        Args:
            x: Tensor that is indexed as ``x[:, None]``, so it needs at least
                one dimension. A 1-D input of length ``N`` produces an output
                of shape ``(N, 2 * len(W))``.
        """
        column = x[:, None]
        row = self.W[None, :]
        # Multiplication order is kept as (x * W) * 2 * pi.
        angles = column * row * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        return torch.cat((sin_part, cos_part), dim=-1)
|