src.acoustools.Optimise.Constraints

Constraints for phases used in Solver.gradient_descent_solver. Each must have the signature phases, **params -> phases

 1'''
 2Constrains for Phases used in Solver.gradient_descent_solver
 3Must have signature phases, **params -> phases 
 4'''
 5
 6import torch
 7from torch import Tensor
 8
 9def constrain_phase_only(phases: Tensor, **params) -> Tensor:
10    '''
11    Normalises amplitude to 1
12    :param phase: Phases
13    :return: Hologram
14    '''
15    return phases / torch.abs(phases)
16
def constrant_normalise_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Scales the whole hologram so that its largest amplitude becomes 1.

    Every element is divided by one scalar, ``torch.max(torch.abs(phases))``,
    taken over all elements (every batch included). Relative amplitudes and
    phase angles are therefore unchanged.

    :param phases: Hologram to normalise
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    peak = torch.abs(phases).max()
    return phases / peak
24
def constrain_sigmoid_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by passing them through a sigmoid, keeping each element's phase angle.

    The new amplitude is ``0.5 * sigmoid(|phases|) + 0.5``. Because ``|phases| >= 0``,
    the sigmoid lies in [0.5, 1), so the output amplitude lies in [0.75, 1).

    :param phases: Complex hologram to constrain
    :param params: Unused; present to match the constraint signature
    :return: Hologram with the same phase angles and sigmoid-constrained amplitudes
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases|: zero-amplitude elements then get
    # a unit phasor rather than NaN (0/0)
    unit_phase = torch.exp(1j * torch.angle(phases))
    con_amp = 0.5 * torch.sigmoid(amplitudes) + 1/2
    return unit_phase * con_amp
36
def constrain_clamp_amp(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by clamping them to [0, 1], keeping each element's phase angle.

    Elements whose amplitude is already at most 1 are left unchanged. Larger
    ones are scaled down to unit amplitude.

    :param phases: Complex hologram to constrain
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases|: zero-amplitude elements would
    # otherwise become NaN (0/0), and NaN * 0 is still NaN
    unit_phase = torch.exp(1j * torch.angle(phases))
    clamp_amp = torch.clamp(amplitudes, min=0, max=1)
    return unit_phase * clamp_amp
47
def normalise_amplitude_normal(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by min-max normalising them along dimension 1, keeping
    each element's phase angle.

    Despite the name, this is not a z-score. Each amplitude becomes
    ``(a - min) / (max - min)``, with min/max taken along dim 1, so the smallest
    amplitude in a slice maps to 0 and the largest to 1.

    When every amplitude along dim 1 is equal (e.g. a phase-only hologram),
    ``max - min`` is 0 and the plain formula would give NaN (0/0). Such slices
    are given amplitude 1 instead.

    :param phases: Complex hologram with at least 2 dimensions. NOTE(review):
        dim 1 is presumably the transducer axis - confirm against callers.
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases| so zero-amplitude elements do not become NaN
    unit_phase = torch.exp(1j * torch.angle(phases))
    amp_min = torch.min(amplitudes, dim=1, keepdim=True).values
    amp_range = torch.max(amplitudes, dim=1, keepdim=True).values - amp_min
    has_range = amp_range > 0
    # Divide by 1 where the range is 0. This keeps inf/NaN out of the unselected
    # torch.where branch, which would otherwise poison the backward pass.
    safe_range = torch.where(has_range, amp_range, torch.ones_like(amp_range))
    norm_dist_amp = torch.where(
        has_range,
        (amplitudes - amp_min) / safe_range,
        torch.ones_like(amplitudes),
    )
    return unit_phase * norm_dist_amp
59
def sine_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a)`` while keeping the element's phase angle.

    Note that ``sin(a)`` is negative for ``pi < a < 2*pi``, which flips that
    element's phase by pi.

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases))
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
65
66
def sine_amplitude_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a)**2`` while keeping the element's phase angle.

    The squared sine keeps every output amplitude in [0, 1].

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases)) ** 2
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
72
def sine_amplitude_pi_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(pi*a - pi/2)**2`` while keeping the element's phase angle.

    Mathematically this equals ``cos(pi*a)**2``: 1 at ``a = 0`` and ``a = 1``,
    and 0 at ``a = 0.5``. Output amplitudes are always in [0, 1].

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    shifted = torch.abs(phases) * torch.pi - torch.pi / 2
    new_amp = torch.sin(shifted) ** 2
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
78
def sine_amplitude_pi(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a * pi / 2)`` while keeping the element's phase angle.

    Amplitudes in [0, 1] map monotonically onto [0, 1]. Amplitudes above 2 give
    a negative sine, which flips the element's phase by pi.

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases) * torch.pi / 2)
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
84
def tanh_amplitude(phases:Tensor, **params):
    '''
    Squashes each amplitude ``a`` to ``tanh(0.1 * a)`` while keeping the element's phase angle.

    Since ``a >= 0``, output amplitudes lie in [0, 1).

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    squashed = torch.tanh(0.1 * torch.abs(phases))
    return squashed * torch.exp(1j * torch.angle(phases))
90
def norm_only_bottom(phases:Tensor, **params):
    '''
    Applies different constraints to the two halves of a 512-element hologram.

    Elements ``[:, 0:256]`` are divided by the largest amplitude among them
    (taken over every batch), keeping their relative amplitudes with a peak of 1.
    Elements ``[:, 256:]`` are made phase-only (unit amplitude).
    NOTE(review): the fixed 256/256 split presumably corresponds to two
    256-transducer arrays - confirm against callers.

    :param phases: Complex hologram with 512 entries along dim 1
    :param params: Unused; present to match the constraint signature
    :return: Hologram with the same shape as ``phases``
    '''
    assert phases.shape[1] == 512, (
        f"norm_only_bottom expects 512 elements along dim 1, got {phases.shape[1]}"
    )
    top = phases[:, 0:256]
    bottom = phases[:, 256:]
    top = top / torch.max(torch.abs(top))
    # exp(i*angle) instead of bottom / |bottom| so zero-amplitude elements become
    # 1+0j rather than NaN (0/0)
    bottom = torch.exp(1j * torch.angle(bottom))
    return torch.cat([top, bottom], dim=1)
def constrain_phase_only(phases: Tensor, **params) -> Tensor:
    '''
    Normalises every element's amplitude to 1, keeping only its phase angle.

    Computed as ``exp(1j * angle(phases))`` (the same pattern the sine/tanh
    constraints in this module use) rather than ``phases / torch.abs(phases)``,
    which evaluates 0/0 = NaN for zero-amplitude elements. Those elements now
    map to 1+0j.

    :param phases: Complex hologram to constrain. NOTE(review): assumes a complex
        dtype - a real tensor would come back as a complex one.
    :param params: Unused; accepted to match the ``phases, **params -> phases``
        constraint signature.
    :return: Hologram with unit amplitude everywhere
    '''
    return torch.exp(1j * torch.angle(phases))

Normalises amplitude to 1

Parameters
  • phases: Phases
Returns

Hologram

def constrant_normalise_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Scales the whole hologram so that its largest amplitude becomes 1.

    Every element is divided by one scalar, ``torch.max(torch.abs(phases))``,
    taken over all elements (every batch included). Relative amplitudes and
    phase angles are therefore unchanged.

    :param phases: Hologram to normalise
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    peak = torch.abs(phases).max()
    return phases / peak

Constrains by dividing by torch.max(torch.abs(phases))

Parameters
  • phases: Phases
Returns

Hologram

def constrain_sigmoid_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by passing them through a sigmoid, keeping each element's phase angle.

    The new amplitude is ``0.5 * sigmoid(|phases|) + 0.5``. Because ``|phases| >= 0``,
    the sigmoid lies in [0.5, 1), so the output amplitude lies in [0.75, 1).

    :param phases: Complex hologram to constrain
    :param params: Unused; present to match the constraint signature
    :return: Hologram with the same phase angles and sigmoid-constrained amplitudes
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases|: zero-amplitude elements then get
    # a unit phasor rather than NaN (0/0)
    unit_phase = torch.exp(1j * torch.angle(phases))
    con_amp = 0.5 * torch.sigmoid(amplitudes) + 1/2
    return unit_phase * con_amp

Constrains amplitudes by passing them through a sigmoid

Parameters
  • phases: Phases
Returns

Hologram

def constrain_clamp_amp(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by clamping them to [0, 1], keeping each element's phase angle.

    Elements whose amplitude is already at most 1 are left unchanged. Larger
    ones are scaled down to unit amplitude.

    :param phases: Complex hologram to constrain
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases|: zero-amplitude elements would
    # otherwise become NaN (0/0), and NaN * 0 is still NaN
    unit_phase = torch.exp(1j * torch.angle(phases))
    clamp_amp = torch.clamp(amplitudes, min=0, max=1)
    return unit_phase * clamp_amp

Constrains by clamping amplitudes to [0, 1]

Parameters
  • phases: Phases
Returns

Hologram

def normalise_amplitude_normal(phases: Tensor, **params) -> Tensor:
    '''
    Constrains amplitudes by min-max normalising them along dimension 1, keeping
    each element's phase angle.

    Despite the name, this is not a z-score. Each amplitude becomes
    ``(a - min) / (max - min)``, with min/max taken along dim 1, so the smallest
    amplitude in a slice maps to 0 and the largest to 1.

    When every amplitude along dim 1 is equal (e.g. a phase-only hologram),
    ``max - min`` is 0 and the plain formula would give NaN (0/0). Such slices
    are given amplitude 1 instead.

    :param phases: Complex hologram with at least 2 dimensions. NOTE(review):
        dim 1 is presumably the transducer axis - confirm against callers.
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    amplitudes = torch.abs(phases)
    # exp(i*angle) instead of phases / |phases| so zero-amplitude elements do not become NaN
    unit_phase = torch.exp(1j * torch.angle(phases))
    amp_min = torch.min(amplitudes, dim=1, keepdim=True).values
    amp_range = torch.max(amplitudes, dim=1, keepdim=True).values - amp_min
    has_range = amp_range > 0
    # Divide by 1 where the range is 0. This keeps inf/NaN out of the unselected
    # torch.where branch, which would otherwise poison the backward pass.
    safe_range = torch.where(has_range, amp_range, torch.ones_like(amp_range))
    norm_dist_amp = torch.where(
        has_range,
        (amplitudes - amp_min) / safe_range,
        torch.ones_like(amplitudes),
    )
    return unit_phase * norm_dist_amp

Constrains by min-max normalising amplitudes along dimension 1 (this is not a z-score)

Parameters
  • phases: Phases
Returns

Hologram

def sine_amplitude(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a)`` while keeping the element's phase angle.

    Note that ``sin(a)`` is negative for ``pi < a < 2*pi``, which flips that
    element's phase by pi.

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases))
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
def sine_amplitude_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a)**2`` while keeping the element's phase angle.

    The squared sine keeps every output amplitude in [0, 1].

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases)) ** 2
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
def sine_amplitude_pi_square(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(pi*a - pi/2)**2`` while keeping the element's phase angle.

    Mathematically this equals ``cos(pi*a)**2``: 1 at ``a = 0`` and ``a = 1``,
    and 0 at ``a = 0.5``. Output amplitudes are always in [0, 1].

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    shifted = torch.abs(phases) * torch.pi - torch.pi / 2
    new_amp = torch.sin(shifted) ** 2
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
def sine_amplitude_pi(phases: Tensor, **params) -> Tensor:
    '''
    Replaces each amplitude ``a`` with ``sin(a * pi / 2)`` while keeping the element's phase angle.

    Amplitudes in [0, 1] map monotonically onto [0, 1]. Amplitudes above 2 give
    a negative sine, which flips the element's phase by pi.

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    new_amp = torch.sin(torch.abs(phases) * torch.pi / 2)
    phase_factor = torch.exp(1j * torch.angle(phases))
    return new_amp * phase_factor
def tanh_amplitude(phases:Tensor, **params):
    '''
    Squashes each amplitude ``a`` to ``tanh(0.1 * a)`` while keeping the element's phase angle.

    Since ``a >= 0``, output amplitudes lie in [0, 1).

    :param phases: Complex hologram
    :param params: Unused; present to match the constraint signature
    :return: Hologram
    '''
    squashed = torch.tanh(0.1 * torch.abs(phases))
    return squashed * torch.exp(1j * torch.angle(phases))
def norm_only_bottom(phases:Tensor, **params):
    '''
    Applies different constraints to the two halves of a 512-element hologram.

    Elements ``[:, 0:256]`` are divided by the largest amplitude among them
    (taken over every batch), keeping their relative amplitudes with a peak of 1.
    Elements ``[:, 256:]`` are made phase-only (unit amplitude).
    NOTE(review): the fixed 256/256 split presumably corresponds to two
    256-transducer arrays - confirm against callers.

    :param phases: Complex hologram with 512 entries along dim 1
    :param params: Unused; present to match the constraint signature
    :return: Hologram with the same shape as ``phases``
    '''
    assert phases.shape[1] == 512, (
        f"norm_only_bottom expects 512 elements along dim 1, got {phases.shape[1]}"
    )
    top = phases[:, 0:256]
    bottom = phases[:, 256:]
    top = top / torch.max(torch.abs(top))
    # exp(i*angle) instead of bottom / |bottom| so zero-amplitude elements become
    # 1+0j rather than NaN (0/0)
    bottom = torch.exp(1j * torch.angle(bottom))
    return torch.cat([top, bottom], dim=1)