Step 1: Data Generation

Generate synthetic neural activity data with Gaussian noise using the PDE_N2 model. This creates the training dataset: 1,000 neurons of 4 types recorded over 100,000 time points.
Outputs:
Supp. Fig 10b: Activity time series used for GNN training
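The data-generation step above can be sketched as a forward Euler simulation of a generic rate model with additive Gaussian noise. This is an illustrative stand-in, not the exact PDE_N2 equations from the paper, and the sizes are scaled down (50 neurons, 1,000 steps instead of 1,000 neurons over 100,000 time points) so it runs quickly:

```python
import numpy as np

rng = np.random.default_rng(0)

n_neurons, n_types, n_steps, dt = 50, 4, 1000, 0.01
noise_std = 0.1  # assumed Gaussian noise level, for illustration only

# Random connectivity and one time constant per neuron type (illustrative choices)
W = rng.normal(0.0, 1.0 / np.sqrt(n_neurons), size=(n_neurons, n_neurons))
types = rng.integers(0, n_types, size=n_neurons)
tau = np.array([0.5, 1.0, 1.5, 2.0])[types]

x = rng.normal(0.0, 0.1, size=n_neurons)
activity = np.empty((n_steps, n_neurons))
for t in range(n_steps):
    # Generic rate dynamics: dx_i/dt = (-x_i + sum_j W_ij tanh(x_j)) / tau_i
    dx = (-x + W @ np.tanh(x)) / tau
    # Euler step plus Gaussian noise scaled by sqrt(dt)
    x = x + dt * dx + noise_std * np.sqrt(dt) * rng.normal(size=n_neurons)
    activity[t] = x

print(activity.shape)  # (1000, 50): noisy time series of the kind used for training
```

The resulting `activity` array plays the role of the noisy observations \(x_i\) that the GNN trains on in the next step.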
Step 2: GNN Training

Train the GNN to learn the connectivity \(W\), the latent embeddings \(\mathbf{a}_i\), and the functions \(\phi^*\) and \(\psi^*\). The GNN learns to predict \(dx_i/dt\) from the noisy observed activity \(x_i\).
The GNN optimizes the update rule (Equation 3 from the paper):

\[
\frac{dx_i}{dt} = \phi^*(x_i, \mathbf{a}_i) + \sum_j W_{ij}\, \psi^*(x_j)
\]

where \(\phi^*\) and \(\psi^*\) are MLPs (ReLU, hidden dim 64, 3 layers), \(\mathbf{a}_i\) is a learnable 2D latent vector per neuron, and \(W\) is the learnable connectivity matrix.
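A forward pass of this parameterization can be sketched in plain numpy. The initializations, the helper names, and the exact argument list of \(\psi^*\) are illustrative assumptions; the real pipeline implements these as trainable PyTorch modules:

```python
import numpy as np

rng = np.random.default_rng(1)
n_neurons, hidden, latent_dim = 20, 64, 2

def make_mlp(in_dim, hidden, out_dim, rng):
    """3-layer ReLU MLP (in -> hidden -> hidden -> out) with random weights."""
    dims = [in_dim, hidden, hidden, out_dim]
    return [(rng.normal(0, np.sqrt(2.0 / d_in), size=(d_in, d_out)), np.zeros(d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])]

def mlp_forward(params, z):
    for i, (Wl, b) in enumerate(params):
        z = z @ Wl + b
        if i < len(params) - 1:
            z = np.maximum(z, 0.0)  # ReLU on hidden layers only
    return z

# Learnable quantities (here just random initializations, not trained values)
W = rng.normal(0, 0.1, size=(n_neurons, n_neurons))   # connectivity matrix
a = rng.normal(0, 1.0, size=(n_neurons, latent_dim))  # 2D latent vector per neuron
phi = make_mlp(1 + latent_dim, hidden, 1, rng)        # phi*(x_i, a_i)
psi = make_mlp(1, hidden, 1, rng)                     # psi*(x_j)

x = rng.normal(size=(n_neurons, 1))  # observed activity at one time point

# Predicted dx_i/dt = phi*(x_i, a_i) + sum_j W_ij psi*(x_j)
dxdt = mlp_forward(phi, np.concatenate([x, a], axis=1)) + W @ mlp_forward(psi, x)
print(dxdt.shape)  # one predicted derivative per neuron
```

During training, the prediction `dxdt` is compared against finite-difference estimates of \(dx_i/dt\) from the noisy data, and gradients flow back into `W`, `a`, and both MLPs.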
Code
```python
# STEP 2: TRAIN
import glob

print()
print("-" * 80)
print("STEP 2: TRAIN - Training GNN to learn W, embeddings, phi, psi from noisy data")
print("-" * 80)

# Check if a trained model already exists (any .pt file in the models folder)
model_files = glob.glob(f'{log_dir}/models/*.pt')
if model_files:
    print(f"trained model already exists at {log_dir}/models/")
    print("skipping training (delete models folder to retrain)")
else:
    print(f"training for {config.training.n_epochs} epochs, {config.training.n_runs} run(s)")
    print(f"learning: connectivity W, latent vectors a_i, functions phi* and psi*")
    print(f"models: {log_dir}/models/")
    print(f"training plots: {log_dir}/tmp_training")
    print(f"tensorboard: tensorboard --logdir {log_dir}/")
    print()
    data_train(
        config=config,
        erase=False,
        best_model=best_model,
        style='color',
        device=device
    )
```
Step 3: GNN Evaluation
Evaluate the trained GNN, reproducing the figures from Supplementary Figure 10 of the paper.
Figure panels:
Supp. Fig 10d: Learned connectivity matrix
Supp. Fig 10e: Comparison of learned vs true connectivity
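The learned-vs-true comparison in Supp. Fig 10e amounts to correlating the two connectivity matrices entrywise. A minimal sketch, using placeholder random matrices (in the real pipeline, `W_true` comes from the simulation config and `W_learned` is read out of the trained model):

```python
import numpy as np

rng = np.random.default_rng(2)
n = 100

# Placeholder matrices standing in for the true and recovered connectivity
W_true = rng.normal(size=(n, n))
W_learned = 0.9 * W_true + 0.1 * rng.normal(size=(n, n))  # imperfect recovery

# Entrywise Pearson correlation between learned and true connectivity
r = np.corrcoef(W_true.ravel(), W_learned.ravel())[0, 1]
print(f"learned-vs-true connectivity correlation: r = {r:.3f}")

# A linear fit recovers the scale, since the learned W may be a scaled copy
slope, intercept = np.polyfit(W_true.ravel(), W_learned.ravel(), 1)
print(f"fitted slope = {slope:.3f}")
```

A scatter plot of `W_learned.ravel()` against `W_true.ravel()`, annotated with `r` and the fitted line, is the usual way to render this panel.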