src.acoustools.Optimise.Constraints
Constraints for phases used in Solver.gradient_descent_solver. Each constraint must have the signature phases, **params -> phases.
'''
Constraints for phases used in Solver.gradient_descent_solver.
Each constraint must have the signature phases, **params -> phases
'''

import torch
from torch import Tensor


def constrain_phase_only(phases: Tensor, **params) -> Tensor:
    '''
    Normalises the amplitude of every element to 1, keeping its phase. \\
    Elements that are exactly zero have no defined phase; they are mapped to 1
    (phase 0, the value `torch.angle` reports for zero) instead of the NaN that
    a plain `phases / torch.abs(phases)` would produce.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    nonzero = amplitudes > 0
    # Swap zero denominators for 1 *before* dividing so that neither the
    # forward value nor the gradient ever contains 0/0.
    safe_amplitudes = torch.where(nonzero, amplitudes, torch.ones_like(amplitudes))
    unit = phases / safe_amplitudes
    return torch.where(nonzero, unit, torch.ones_like(unit))


def constrant_normalise_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains by dividing by `torch.max(torch.abs(phases))`, the largest amplitude
    over the whole tensor. \\
    An all-zero input is returned as zeros rather than NaN.
    :param phases: Phases
    :return: Hologram
    '''
    peak = torch.max(torch.abs(phases))
    # A zero peak means every element is zero; divide by 1 to avoid 0/0.
    safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    return phases / safe_peak


def constrain_sigmoid_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by passing it through a sigmoid: the new amplitude is
    `0.5 * sigmoid(|phases|) + 0.5`, the phase is kept. \\
    Zero-valued elements are treated as having phase 0 instead of producing NaN.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    nonzero = amplitudes > 0
    # Guard the division: zero amplitudes would otherwise give 0/0 = NaN.
    safe_amplitudes = torch.where(nonzero, amplitudes, torch.ones_like(amplitudes))
    unit = phases / safe_amplitudes
    norm_holo = torch.where(nonzero, unit, torch.ones_like(unit))
    con_amp = 0.5 * torch.sigmoid(amplitudes) + 1/2
    return norm_holo * con_amp


def constrain_clamp_amp(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by clamping it to at most 1, the phase is kept. \\
    Elements with amplitude <= 1 (including zero) are returned unchanged.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # (phases / |phases|) * clamp(|phases|, 0, 1) == phases / max(|phases|, 1);
    # this form never divides by zero.
    return phases / torch.clamp(amplitudes, min=1)


def normalise_amplitude_normal(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by min-max scaling along dim 1: within each row the
    smallest amplitude becomes 0 and the largest becomes 1, the phase is kept. \\
    A row whose amplitudes are all equal has zero range; its amplitudes are
    mapped to 0 rather than NaN.
    :param phases: Phases, indexed along dim 1 (so at least 2 dimensions)
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # Guard the phase extraction against zero amplitudes (0/0).
    safe_amplitudes = torch.where(amplitudes > 0, amplitudes, torch.ones_like(amplitudes))
    norm_holo = phases / safe_amplitudes

    lowest = torch.min(amplitudes, dim=1, keepdim=True).values
    highest = torch.max(amplitudes, dim=1, keepdim=True).values
    span = highest - lowest
    # Guard the scaling against constant rows (max == min).
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    norm_dist_amp = (amplitudes - lowest) / safe_span
    return norm_holo * norm_dist_amp


def sine_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Replaces the amplitude with `sin(|phases|)`, the phase is kept.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    sin_amp = torch.sin(amplitudes)
    angles = torch.angle(phases)
    return sin_amp * torch.exp(1j*angles)


def sine_amplitude_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces the amplitude with `sin(|phases|)**2`, the phase is kept.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    sin_amp = torch.sin(amplitudes)**2
    angles = torch.angle(phases)
    return sin_amp * torch.exp(1j*angles)


def sine_amplitude_pi_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces the amplitude with `sin(pi * |phases| - pi/2)**2`, the phase is kept.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    sin_amp = torch.sin(amplitudes * torch.pi - torch.pi/2)**2
    angles = torch.angle(phases)
    return sin_amp * torch.exp(1j*angles)


def sine_amplitude_pi(phases: Tensor, **params) -> Tensor:
    '''
    Replaces the amplitude with `sin(|phases| * pi/2)`, the phase is kept.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    sin_amp = torch.sin(amplitudes * torch.pi/2)
    angles = torch.angle(phases)
    return sin_amp * torch.exp(1j*angles)


def tanh_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Replaces the amplitude with `tanh(0.1 * |phases|)`, the phase is kept.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    amp = torch.tanh(0.1*amplitudes)
    angles = torch.angle(phases)
    return amp * torch.exp(1j*angles)


def norm_only_bottom(phases: Tensor, **params) -> Tensor:
    '''
    Applies a different constraint to each half of a 512-column hologram: \\
    columns 0-255 are divided by their largest amplitude, columns 256-511 are
    made phase-only (amplitude 1). \\
    Zero values no longer produce NaN: an all-zero first half stays zero and
    zero elements of the second half are mapped to 1 (phase 0).
    :param phases: Phases, must have size 512 along dim 1
    :return: Hologram
    '''
    assert phases.shape[1] == 512
    top = phases[:, 0:256]
    bottom = phases[:, 256:]

    top_peak = torch.max(torch.abs(top))
    # A zero peak means the whole first half is zero; divide by 1 instead.
    safe_peak = torch.where(top_peak > 0, top_peak, torch.ones_like(top_peak))
    top = top / safe_peak

    bottom_amp = torch.abs(bottom)
    nonzero = bottom_amp > 0
    # Swap zero denominators for 1 before dividing to avoid 0/0.
    safe_amp = torch.where(nonzero, bottom_amp, torch.ones_like(bottom_amp))
    bottom_unit = bottom / safe_amp
    bottom = torch.where(nonzero, bottom_unit, torch.ones_like(bottom_unit))

    return torch.cat([top, bottom], dim=1)
def
constrain_phase_only(phases: torch.Tensor, **params) -> torch.Tensor:
def constrain_phase_only(phases: Tensor, **params) -> Tensor:
    '''
    Normalises the amplitude of every element to 1, keeping its phase. \\
    Elements that are exactly zero have no defined phase; they are mapped to 1
    (phase 0, the value `torch.angle` reports for zero) instead of the NaN that
    a plain `phases / torch.abs(phases)` would produce.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    nonzero = amplitudes > 0
    # Swap zero denominators for 1 *before* dividing so that neither the
    # forward value nor the gradient ever contains 0/0.
    safe_amplitudes = torch.where(nonzero, amplitudes, torch.ones_like(amplitudes))
    unit = phases / safe_amplitudes
    return torch.where(nonzero, unit, torch.ones_like(unit))
Normalises amplitude to 1
Parameters
- phases: Phases
Returns
Hologram
def
constrant_normalise_amplitude(phases: torch.Tensor, **params) -> torch.Tensor:
def constrant_normalise_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains by dividing by `torch.max(torch.abs(phases))`, the largest amplitude
    over the whole tensor. \\
    An all-zero input is returned as zeros rather than NaN.
    :param phases: Phases
    :return: Hologram
    '''
    peak = torch.max(torch.abs(phases))
    # A zero peak means every element is zero; divide by 1 to avoid 0/0.
    safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    return phases / safe_peak
Constrains by dividing by torch.max(torch.abs(phases))
Parameters
- phases: Phases
Returns
Hologram
def
constrain_sigmoid_amplitude(phases: torch.Tensor, **params) -> torch.Tensor:
def constrain_sigmoid_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by passing it through a sigmoid: the new amplitude is
    `0.5 * sigmoid(|phases|) + 0.5`, the phase is kept. \\
    Zero-valued elements are treated as having phase 0 instead of producing NaN.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    nonzero = amplitudes > 0
    # Guard the division: zero amplitudes would otherwise give 0/0 = NaN.
    safe_amplitudes = torch.where(nonzero, amplitudes, torch.ones_like(amplitudes))
    unit = phases / safe_amplitudes
    norm_holo = torch.where(nonzero, unit, torch.ones_like(unit))
    con_amp = 0.5 * torch.sigmoid(amplitudes) + 1/2
    return norm_holo * con_amp
Constrains amplitude by passing it through a sigmoid
Parameters
- phases: Phases
Returns
Hologram
def
constrain_clamp_amp(phases: torch.Tensor, **params) -> torch.Tensor:
def constrain_clamp_amp(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by clamping it to at most 1, the phase is kept. \\
    Elements with amplitude <= 1 (including zero) are returned unchanged.
    :param phases: Phases
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # (phases / |phases|) * clamp(|phases|, 0, 1) == phases / max(|phases|, 1);
    # this form never divides by zero.
    return phases / torch.clamp(amplitudes, min=1)
Constrains amplitude by clamping it to the range [0, 1]
Parameters
- phases: Phases
Returns
Hologram
def
normalise_amplitude_normal(phases: torch.Tensor, **params) -> torch.Tensor:
def normalise_amplitude_normal(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitude by min-max scaling along dim 1: within each row the
    smallest amplitude becomes 0 and the largest becomes 1, the phase is kept. \\
    A row whose amplitudes are all equal has zero range; its amplitudes are
    mapped to 0 rather than NaN.
    :param phases: Phases, indexed along dim 1 (so at least 2 dimensions)
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # Guard the phase extraction against zero amplitudes (0/0).
    safe_amplitudes = torch.where(amplitudes > 0, amplitudes, torch.ones_like(amplitudes))
    norm_holo = phases / safe_amplitudes

    lowest = torch.min(amplitudes, dim=1, keepdim=True).values
    highest = torch.max(amplitudes, dim=1, keepdim=True).values
    span = highest - lowest
    # Guard the scaling against constant rows (max == min).
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    norm_dist_amp = (amplitudes - lowest) / safe_span
    return norm_holo * norm_dist_amp
Constrains amplitude by min-max normalising it to [0, 1] along dim 1
Parameters
- phases: Phases
Returns
Hologram
def
sine_amplitude(phases: torch.Tensor, **params) -> torch.Tensor:
def
sine_amplitude_square(phases: torch.Tensor, **params) -> torch.Tensor:
def
sine_amplitude_pi_square(phases: torch.Tensor, **params) -> torch.Tensor:
def
sine_amplitude_pi(phases: torch.Tensor, **params) -> torch.Tensor:
def
tanh_amplitude(phases: torch.Tensor, **params):
def
norm_only_bottom(phases: torch.Tensor, **params):