Skip to content
Snippets Groups Projects
Commit 4739f482 authored by Matteo Argenton's avatar Matteo Argenton
Browse files

add support for depth and hidden dim of mlp in CGNN

parent a77e4a25
No related branches found
Tags v0.1.0
No related merge requests found
......@@ -4,6 +4,8 @@ from flax import linen as nn
#### Define Edge Network
class EdgeNet(nn.Module):
hid_dim: int
c_depth: int
inner_mlp_dim: int
@nn.compact
def __call__(self, X, Ri, Ro):
......@@ -13,8 +15,9 @@ class EdgeNet(nn.Module):
B = jnp.concatenate([bo, bi], axis = 1)
B = nn.Dense(self.hid_dim)(B)
B = nn.tanh(B)
for _ in range(self.c_depth):
B = nn.Dense(self.inner_mlp_dim)(B)
B = nn.tanh(B)
B = nn.Dense(1)(B)
B = nn.sigmoid(B)
......@@ -23,6 +26,8 @@ class EdgeNet(nn.Module):
#### Define Node Network
class NodeNet(nn.Module):
hid_dim: int
c_depth: int
inner_mlp_dim: int
@nn.compact
def __call__(self, X, e, Ri, Ro):
......@@ -37,8 +42,9 @@ class NodeNet(nn.Module):
mo = jnp.tensordot(Rwo, bi, axes=([1],[0]))
M = jnp.concatenate([mi, mo, X], axis=1)
M = nn.Dense(self.hid_dim)(M)
M = nn.tanh(M)
for _ in range(self.c_depth):
M = nn.Dense(self.inner_mlp_dim)(M)
M = nn.tanh(M)
M = nn.Dense(self.hid_dim)(M)
M = nn.sigmoid(M)
......
......@@ -138,7 +138,7 @@ def model_setup(config, session_id=0):
NodeNet(config['NN_qc']['n_qubits'], config['hid_dim'], QLayer(config['NN_qc'], session_id, config['backend'])))
elif config['network'] == 'CGNN':
from qnetworks.CGNN import GNN, EdgeNet, NodeNet
model = GNN(config['hid_dim'], config['n_iters'], EdgeNet(config['hid_dim']), NodeNet(config['hid_dim']))
model = GNN(config['hid_dim'], config['n_iters'], EdgeNet(config['hid_dim'], config['c_depth'], config['inner_mlp_dim']), NodeNet(config['hid_dim'], config['c_depth'], config['inner_mlp_dim']))
else:
print('Wrong network specification!')
sys.exit()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment