Signaling system with 998 nodes
Signaling
Simulation
This script creates the seventh column of the paper's Figure 2: a simulation of a signaling network with 986 nodes, 17,865 edges, and 2 node types. Note that 100 datasets are generated to test training with multiple trials.
First, we load the configuration file and set the device.
# ParticleGraphConfig and set_device come from the decomp-gnn code base (import lines not shown in this excerpt).
config_file = 'signal_N_100_2'
config = ParticleGraphConfig.from_yaml(f'./config/{config_file}.yaml')
device = set_device("auto")
The following model is used to simulate the signaling network with PyTorch Geometric.
import torch
import torch_geometric as pyg
import torch_geometric.utils as pyg_utils


class SignalingNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html

    Inputs
    ----------
    data : a torch_geometric.data object

    Returns
    -------
    du : torch.Tensor of shape (n_nodes, 1), the time derivative of the node states u
    """

    def __init__(self, aggr_type='add', p=None, bc_dpos=None):
        super().__init__(aggr=aggr_type)
        self.p = p  # per-type parameters, one row (b, c) per node type
        self.bc_dpos = bc_dpos

    def forward(self, data, return_all=False):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # Remove self-loops and drop the matching edge weights so both stay aligned.
        edge_index, edge_attr = pyg_utils.remove_self_loops(edge_index, edge_attr)

        particle_type = x[:, 5].long()  # column 5: node type
        parameters = self.p[particle_type]
        b = parameters[:, 0:1]
        c = parameters[:, 1:2]
        u = x[:, 6:7]  # column 6: node state u

        # msg_i = sum over incoming edges j -> i of edge_attr_ji * tanh(u_j)
        msg = self.propagate(edge_index, u=u, edge_attr=edge_attr)
        local = -b * u + c * torch.tanh(u)
        du = local + msg
        if return_all:
            return du, local, msg
        return du

    def message(self, u_j, edge_attr):
        # Store the latest activations so they can be inspected after a forward pass.
        self.activation = torch.tanh(u_j)
        self.u_j = u_j
        return edge_attr[:, None] * self.activation


def bc_pos(x):
    # Wrap positions into the periodic unit box [0, 1).
    return torch.remainder(x, 1.0)


def bc_dpos(x):
    # Wrap displacements into [-0.5, 0.5) (minimum-image convention).
    return torch.remainder(x - 0.5, 1.0) - 0.5

The data is generated with the above PyTorch Geometric model. If the simulation is too large, you can decrease n_particles (a multiple of 2) and n_nodes in “signal_N_100_2.yaml”.
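Written out per node, the forward pass computes du_i/dt = -b_i·u_i + c_i·tanh(u_i) + Σ_j w_ji·tanh(u_j). Here b_i and c_i are looked up from the type of node i, and w_ji is the weight of edge j → i stored in edge_attr. The sketch below reproduces this right-hand side in plain PyTorch. It replaces propagate with index_add_ so that the sum over incoming edges is explicit. It is an illustration under assumptions, not the project's code: it assumes sum aggregation ('add'), a 1-D edge weight tensor, and parameters stored as an (n_types, 2) tensor with columns (b, c). The function name signaling_rhs and the toy graph are hypothetical.

```python
import torch


def signaling_rhs(u, node_type, params, edge_index, edge_weight):
    """Hypothetical explicit re-implementation of SignalingNetwork.forward.

    u:           (n_nodes, 1) node states (column 6 of data.x in the model)
    node_type:   (n_nodes,) integer node types (column 5 of data.x)
    params:      (n_types, 2) per-type parameters, columns (b, c)
    edge_index:  (2, n_edges) rows = (source j, target i)
    edge_weight: (n_edges,) edge weights (edge_attr in the model)
    """
    # Drop self-loops together with their weights.
    keep = edge_index[0] != edge_index[1]
    src, dst = edge_index[0, keep], edge_index[1, keep]
    w = edge_weight[keep]

    # Per-node parameters gathered from the node type.
    b = params[node_type, 0:1]
    c = params[node_type, 1:2]

    # Message on edge j -> i is w_ji * tanh(u_j); messages are summed at the target.
    messages = w[:, None] * torch.tanh(u[src])
    msg = torch.zeros_like(u).index_add_(0, dst, messages)

    return -b * u + c * torch.tanh(u) + msg


# Toy graph: 3 nodes, 2 node types, edges 0->1 and 2->1, plus a self-loop on node 0.
u = torch.tensor([[0.5], [0.0], [-1.0]])
node_type = torch.tensor([0, 1, 0])
params = torch.tensor([[1.0, 0.5], [2.0, 1.0]])
edge_index = torch.tensor([[0, 2, 0], [1, 1, 0]])
edge_weight = torch.tensor([1.0, -0.5, 3.0])

du = signaling_rhs(u, node_type, params, edge_index, edge_weight)
```

You can check the edge direction by comparing signaling_rhs with SignalingNetwork on the same graph. Under PyTorch Geometric's default source_to_target flow, edge_index[0] holds the sending node j and edge_index[1] holds the receiving node i.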
# Per-type parameters (b, c) from the config, placed on the same device as the data.
p = torch.squeeze(torch.tensor(config.simulation.params, device=device))
model = SignalingNetwork(aggr_type=config.graph_model.aggr_type, p=p, bc_dpos=bc_dpos)
# Keyword arguments for the data generation, training and testing steps.
generate_kwargs = dict(device=device, visualize=True, run_vizualized=0, style='color', alpha=1, erase=True, save=True, step=10)
train_kwargs = dict(device=device, erase=True)
test_kwargs = dict(device=device, visualize=True, style='color', verbose=False, best_model='20', run=0, step=1, save_velocity=True)
# data_generate_synaptic is provided by the decomp-gnn code base.
data_generate_synaptic(config, model, **generate_kwargs)

Finally, we generate the figures that are shown in Figure 2. The frames of the first six datasets are saved in ‘decomp-gnn/paper_experiments/graphs_data/graphs_signal_N_100_2/Fig/’.
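The call to data_generate_synaptic integrates the signaling dynamics over time for each trial. Its internals are not shown here, so the following is only a minimal sketch of how multiple trials could be produced. It makes several assumptions: explicit Euler steps with a hypothetical time step dt, independent random initial states per trial, and one fixed graph shared by all trials. The names rhs and simulate_trials, the ring graph and the parameter values are all hypothetical. The result is stacked into a tensor of shape (n_trials, n_steps + 1, n_nodes).

```python
import torch


def rhs(u, b, c, src, dst, w):
    # du/dt = -b*u + c*tanh(u) + sum over incoming edges of w * tanh(u_source)
    msg = torch.zeros_like(u).index_add_(0, dst, w[:, None] * torch.tanh(u[src]))
    return -b * u + c * torch.tanh(u) + msg


def simulate_trials(n_trials, n_steps, dt, b, c, src, dst, w, seed=0):
    """Hypothetical Euler integrator; returns (n_trials, n_steps + 1, n_nodes)."""
    gen = torch.Generator().manual_seed(seed)
    n_nodes = b.shape[0]
    runs = []
    for _ in range(n_trials):
        u = torch.randn(n_nodes, 1, generator=gen)  # assumed random initial state
        traj = [u]
        for _ in range(n_steps):
            u = u + dt * rhs(u, b, c, src, dst, w)  # explicit Euler step
            traj.append(u)
        runs.append(torch.cat(traj, dim=1).T)  # (n_steps + 1, n_nodes)
    return torch.stack(runs)


# Toy setup: 6 nodes on a directed ring, 2 node types, hypothetical (b, c) values.
node_type = torch.tensor([0, 0, 0, 1, 1, 1])
params = torch.tensor([[1.0, 0.5], [2.0, 1.0]])
b, c = params[node_type, 0:1], params[node_type, 1:2]
src = torch.tensor([0, 1, 2, 3, 4, 5])
dst = torch.tensor([1, 2, 3, 4, 5, 0])
w = torch.randn(6, generator=torch.Generator().manual_seed(1))

data = simulate_trials(n_trials=4, n_steps=50, dt=0.01, b=b, c=c, src=src, dst=dst, w=w)
print(data.shape)  # torch.Size([4, 51, 6])
```

Explicit Euler is chosen here only for simplicity. The actual generator may use a different scheme, time step or initialization, so check the config before relying on these details.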



