import torch
import numpy as np
class GelNet:
    """
    GelNet regularizer for a linear decoder [Sokolov2016]_.

    The penalty on a weight matrix ``W`` is
    ``lambda1 * sum(|W[d]|) + lambda2 * trace(W^T P W)``.

    - If ``P`` is the identity matrix (or ``None``), this is the Elastic net.
    - If ``lambda1`` is 0, this is an L2 regularization.
    - If ``lambda2`` is 0, this is an L1 regularization.

    ``d`` needs to be a `{0,1}`-matrix; only entries where ``d`` is 1 receive
    the L1 shrinkage.

    The regularizer has to be applied inside the training loop. The L2 term
    is added to the loss before back-propagation. The L1 term is applied as a
    proximal (soft-thresholding) step after the optimizer step.

    Examples
    --------
    ::

        gelnet = GelNet(lambda1, lambda2, P, lr=lr)
        loss = MSE(X_hat, X)
        # L2 term
        loss += gelnet.quadratic_update(decoder.weight)
        loss.backward()
        optimizer.step()
        # L1 proximal operator update
        gelnet.proximal_update(decoder.weight)

    Parameters
    ----------
    lambda1
        L1-regularization coefficient.
    lambda2
        L2-regularization coefficient.
    P
        Penalty matrix (e.g. gene-network Laplacian). It is square, with side
        equal to ``weights.size(0)``. ``None`` is treated as the identity
        matrix.
    d
        Domain-knowledge matrix (e.g. mask) selecting the weights subject to
        the L1 penalty. Nonzero entries are penalized. ``None`` penalizes
        every weight.
    lr
        Step size of the proximal update (threshold ``lambda1 * lr``).
        Typically the optimizer's learning rate.
    use_gpu
        Sets ``self.dev``. The stored tensors are moved to the device of the
        weights they are applied to, regardless of this flag.
    """

    def __init__(self,
                 lambda1: float,
                 lambda2: float,
                 P: np.ndarray = None,
                 d: np.ndarray = None,
                 lr: float = 1e-3,
                 use_gpu: bool = False):
        self.l1 = lambda1
        self.l2 = lambda2
        # Always define the attribute: quadratic_update checks it for None.
        self.P = (None if P is None
                  else torch.tensor(np.asarray(P), dtype=torch.float32))
        # Boolean mask; nonzero entries of ``d`` become True.
        self.d = None if d is None else torch.tensor(np.asarray(d)).bool()
        self.lr = lr
        self.dev = torch.device('cuda') if use_gpu else torch.device('cpu')

    def quadratic_update(self, weights):
        """
        Compute the L2 term of GelNet, ``lambda2 * trace(W^T P W)``.

        This equals the sum of ``w_k^T P w_k`` over the columns ``w_k`` of
        ``weights``. Add the result to the loss before ``backward()``.

        Parameters
        ----------
        weights
            Layer's weight matrix ``W``.

        Returns
        -------
        A 0-dim tensor with the dtype and device of ``weights``. It is zero
        when ``lambda2`` is 0.
        """
        if self.l2 == 0:
            return weights.new_zeros(())
        if self.P is None:
            # Identity penalty matrix: trace(W^T W) is the squared Frobenius norm.
            penalty = weights.pow(2).sum()
        else:
            # Match the weights' device/dtype. ``.to`` is a no-op once the
            # tensor already matches, so the transfer happens only once.
            self.P = self.P.to(device=weights.device, dtype=weights.dtype)
            penalty = torch.einsum('bi,ij,jb->', weights.t(), self.P, weights)
        return self.l2 * penalty

    def proximal_update(self, weights):
        """
        Proximal operator for the L1 term inducing sparsity.

        Soft-thresholds ``weights`` in place with threshold
        ``lambda1 * lr``. Values above it are decreased by it, values below
        its negative are increased by it, and all others are set to 0. If
        ``d`` was given, only the entries selected by ``d`` change. Call this
        after ``optimizer.step()``.

        Parameters
        ----------
        weights
            Layer's weight matrix (modified in place).
        """
        if self.l1 == 0:
            return
        norm = self.l1 * self.lr
        with torch.no_grad():
            w_update = weights.detach().clone()
            w_geq = w_update > norm
            w_leq = w_update < -norm
            # Computed before masking so it reflects the raw values.
            w_sparse = ~w_geq & ~w_leq
            if self.d is not None:
                # Keep the mask on the weights' device (cached after first move).
                self.d = self.d.to(weights.device)
                w_geq &= self.d
                w_leq &= self.d
                w_sparse &= self.d
            w_update[w_geq] -= norm
            w_update[w_leq] += norm
            w_update[w_sparse] = 0.
            # In-place copy keeps the parameter's storage (and any views) intact.
            weights.copy_(w_update)
class LassoRegularizer:
    """
    Lasso (L1) regularizer for a linear decoder.

    Similar to the [Rybakov2020]_ lasso regularization. The L1 penalty is
    applied as a proximal (soft-thresholding) step after each optimizer step.
    There is no quadratic term.

    Examples
    --------
    ::

        lasso = LassoRegularizer(lambda1, lr)
        loss = MSE(X_hat, X)
        loss += lasso.quadratic_update(decoder.weight)  # always zero
        loss.backward()
        optimizer.step()
        # L1 proximal operator update
        lasso.proximal_update(decoder.weight)

    Parameters
    ----------
    lambda1
        L1-regularization coefficient.
    lr
        Step size of the proximal update (threshold ``lambda1 * lr``).
        Typically the optimizer's learning rate.
    d
        Domain-knowledge matrix (e.g. mask) selecting the weights subject to
        the L1 penalty. Nonzero entries are penalized. ``None`` penalizes
        every weight.
    use_gpu
        Sets ``self.dev``. The mask is moved to the device of the weights it
        is applied to, regardless of this flag.
    """

    def __init__(self,
                 lambda1: float,
                 lr: float,
                 d: np.ndarray = None,
                 use_gpu: bool = False):
        self.l1 = lambda1
        self.lr = lr
        # Boolean mask; nonzero entries of ``d`` become True.
        self.d = None if d is None else torch.tensor(np.asarray(d)).bool()
        self.dev = torch.device('cuda') if use_gpu else torch.device('cpu')

    def quadratic_update(self, weights):
        """
        Return the (absent) L2 term.

        The result is a zero 0-dim tensor with the dtype and device of
        ``weights``. The method exists so this regularizer exposes the same
        two-step interface as GelNet.
        """
        return weights.new_zeros(())

    def proximal_update(self, weights):
        """
        Proximal operator for the L1 term inducing sparsity.

        Soft-thresholds ``weights`` in place with threshold
        ``lambda1 * lr``. Values above it are decreased by it, values below
        its negative are increased by it, and all others are set to 0. If
        ``d`` was given, only the entries selected by ``d`` change. Call this
        after ``optimizer.step()``.

        Parameters
        ----------
        weights
            Layer's weight matrix (modified in place).
        """
        if self.l1 == 0:
            return
        norm = self.l1 * self.lr
        with torch.no_grad():
            w_update = weights.detach().clone()
            w_geq = w_update > norm
            w_leq = w_update < -norm
            # Computed before masking so it reflects the raw values.
            w_sparse = ~w_geq & ~w_leq
            if self.d is not None:
                # Keep the mask on the weights' device (cached after first move).
                self.d = self.d.to(weights.device)
                w_geq &= self.d
                w_leq &= self.d
                w_sparse &= self.d
            w_update[w_geq] -= norm
            w_update[w_leq] += norm
            w_update[w_sparse] = 0.
            # In-place copy keeps the parameter's storage (and any views) intact.
            weights.copy_(w_update)