This script reproduces the panels of the paper's Supplementary Figure 7. Performance scales with the length of the training series. This notebook displays the connectivity matrix comparison, \(\phi^*\) plots, \(\psi^*\) plots, and learned embeddings for each dataset size.
Simulation parameters (constant across all experiments):

- N_neurons: 1000
- N_types: 4, parameterized by \(\tau_i \in \{0.5, 1\}\), \(s_i \in \{1, 2\}\), and \(g_i = 10\)
- Connectivity: 100% (dense), weights drawn from a Lorentz (Cauchy) distribution
The simulation follows Equation 2 from the paper:
\[\frac{dx_i}{dt} = -\frac{x_i}{\tau_i} + s_i \cdot \tanh(x_i) + g_i \cdot \sum_j W_{ij} \cdot \tanh(x_j)\]
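The dynamics of Equation 2 can be sketched with a small explicit-Euler integration. This is a toy illustration, not the pipeline's actual generation code: the network size, time step, initial condition, and the \(1/n\) scaling of the Cauchy weights (used here to keep the recurrent input bounded) are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy scale: far fewer neurons than the paper's 1000.
n = 100
tau = rng.choice([0.5, 1.0], size=n)   # per-neuron time constants
s = rng.choice([1.0, 2.0], size=n)     # per-neuron self-coupling
g = 10.0                               # shared gain

# Dense connectivity with Lorentz (Cauchy) distributed weights;
# the 1/n scaling is an assumption to keep the recurrent drive bounded.
W = rng.standard_cauchy((n, n)) / n

def step(x, dt=0.01):
    """One explicit-Euler step of dx/dt = -x/tau + s*tanh(x) + g*W@tanh(x)."""
    dx = -x / tau + s * np.tanh(x) + g * (W @ np.tanh(x))
    return x + dt * dx

x = rng.normal(0.0, 0.1, n)
traj = np.empty((1000, n))
for t in range(1000):
    x = step(x)
    traj[t] = x
```

Because \(\tanh\) saturates, the recurrent input is bounded and the leak term \(-x_i/\tau_i\) keeps the trajectories finite even with heavy-tailed weights.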
Variable: Training dataset size (n_frames)

| Config | n_frames |
| --- | --- |
| signal_fig_2 | 100,000 |
| signal_fig_supp_7_1 | 50,000 |
| signal_fig_supp_7_2 | 40,000 |
| signal_fig_supp_7_3 | 30,000 |
| signal_fig_supp_7_4 | 20,000 |
| signal_fig_supp_7_5 | 10,000 |
Configuration
Code
import glob
import os  # used below for path checks and directory creation

print()
print("=" * 80)
print("Supplementary Figure 7: Effect of Training Dataset Size")
print("=" * 80)

# All configs to process (config_name, n_frames)
config_list = [
    ('signal_fig_2', 100000),
    ('signal_fig_supp_7_1', 50000),
    ('signal_fig_supp_7_2', 40000),
    ('signal_fig_supp_7_3', 30000),
    ('signal_fig_supp_7_4', 20000),
    ('signal_fig_supp_7_5', 10000),
]

device = []
best_model = ''
config_root = "./config"
Steps 1-4: Generate, Train, Plot, and Training Progression for all configs
Loop over all dataset sizes: generate data, train the GNN, generate plots, and compute the R² training progression. Steps are skipped if data/models/results already exist.
Code
for config_file_, n_frames in config_list:
    print()
    print("=" * 80)
    print(f"Processing: {config_file_} (n_frames={n_frames:,})")
    print("=" * 80)

    config_file, pre_folder = add_pre_folder(config_file_)

    # Load config
    config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
    config.config_file = config_file
    config.dataset = config_file

    if device == []:
        device = set_device(config.training.device)

    log_dir = f'./log/{config_file}'
    graphs_dir = f'./graphs_data/{config_file}'

    # STEP 1: GENERATE
    print()
    print("-" * 80)
    print("STEP 1: GENERATE - Simulating neural activity")
    print("-" * 80)
    data_file = f'{graphs_dir}/x_list_0.npy'
    if os.path.exists(data_file):
        print(f"data already exists at {graphs_dir}/")
        print("skipping simulation, regenerating figures...")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
            regenerate_plots_only=True,
        )
    else:
        print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_frames} frames")
        print(f"output: {graphs_dir}/")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
        )

    # STEP 2: TRAIN
    print()
    print("-" * 80)
    print("STEP 2: TRAIN - Training GNN")
    print("-" * 80)
    model_files = glob.glob(f'{log_dir}/models/*.pt')
    if model_files:
        print(f"trained model already exists at {log_dir}/models/")
        print("skipping training (delete models folder to retrain)")
    else:
        print(f"training for {config.training.n_epochs} epochs")
        print(f"n_frames: {config.simulation.n_frames}")
        data_train(
            config=config,
            erase=False,
            best_model=best_model,
            style='color',
            device=device,
        )

    # STEP 3: PLOT
    print()
    print("-" * 80)
    print("STEP 3: PLOT - Generating figures")
    print("-" * 80)
    folder_name = f'{log_dir}/tmp_results/'
    os.makedirs(folder_name, exist_ok=True)
    data_plot(
        config=config,
        config_file=config_file,
        epoch_list=['best'],
        style='color',
        extended='plots',
        device=device,
        apply_weight_correction=True,
        plot_eigen_analysis=False,
    )

    # STEP 4: TRAINING PROGRESSION (R² over iterations)
    print()
    print("-" * 80)
    print("STEP 4: TRAINING PROGRESSION - Computing R² over iterations")
    print("-" * 80)
    r2_file = f'{log_dir}/results/all/r2_over_iterations.json'
    if os.path.exists(r2_file):
        print(f"R² data already exists at {r2_file}")
        print("skipping (delete results/all/ folder to recompute)")
    else:
        data_plot(
            config=config,
            config_file=config_file,
            epoch_list=['all'],
            style='color',
            extended='plots',
            device=device,
            apply_weight_correction=True,
            plot_eigen_analysis=False,
        )
Connectivity Matrix Comparison
Learned vs true connectivity matrix \(W_{ij}\) after training. The scatter plot shows \(R^2\) and slope metrics.
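The \(R^2\) and slope metrics of the scatter can be sketched as follows. The variable names (`W_true`, `W_learned`) and the synthetic "learned" matrix are stand-ins for illustration, not the pipeline's saved outputs.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic stand-ins: a true Cauchy-distributed connectivity matrix and an
# imperfectly recovered copy (shrunk slope plus small additive noise).
W_true = rng.standard_cauchy((50, 50)) / 50
W_learned = 0.9 * W_true + rng.normal(0.0, 0.01, (50, 50))

# Flatten both matrices and compare entry-wise.
x = W_true.ravel()
y = W_learned.ravel()

slope, intercept = np.polyfit(x, y, 1)   # least-squares fit y ~ slope*x + intercept
r2 = np.corrcoef(x, y)[0, 1] ** 2        # R^2 of the learned-vs-true scatter

print(f"slope={slope:.3f}, R2={r2:.3f}")
```

A slope near 1 with high \(R^2\) indicates the learned weights recover the true ones up to small noise; a slope away from 1 signals a global rescaling of the weights.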
Update Function \(\phi^*(\mathbf{a}_i, x)\) (MLP0)
Learned update functions after training. Each curve represents one neuron. Colors indicate true neuron types. True functions overlaid in gray.
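If the update function is read as the local (non-recurrent) part of Equation 2, i.e. \(\phi(x) = -x/\tau_i + s_i \tanh(x)\) (an assumption about what the gray reference curves show), the four true curves follow directly from the parameter grid:

```python
import numpy as np

def phi(x, tau, s):
    """Local part of Equation 2: -x/tau + s*tanh(x) (assumed form of phi)."""
    return -x / tau + s * np.tanh(x)

x = np.linspace(-5.0, 5.0, 201)

# The 4 neuron types are the (tau_i, s_i) combinations from the setup above.
types = [(0.5, 1.0), (0.5, 2.0), (1.0, 1.0), (1.0, 2.0)]
curves = {(tau, s): phi(x, tau, s) for tau, s in types}
```

Each curve is an odd function of \(x\); the \(\tau_i\) term controls the linear decay at large \(|x|\) while \(s_i\) scales the saturating bump near the origin.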
Transfer Function \(\psi^*(x)\) (MLP1)
Learned transfer function after training, normalized to max=1. True function overlaid in gray.
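The max = 1 normalization can be sketched as below; `psi_learned` is a synthetic stand-in for the MLP1 output, and \(\tanh\) is taken as the true transfer function per Equation 2.

```python
import numpy as np

x = np.linspace(-5.0, 5.0, 201)

# Synthetic stand-in for the learned transfer function (scaled, slightly
# sharpened tanh), mimicking an MLP output that is correct up to gain.
psi_learned = 0.7 * np.tanh(1.1 * x)

# Normalize so the maximum magnitude is 1, then measure the deviation
# from the true tanh on the same grid.
psi_norm = psi_learned / np.abs(psi_learned).max()
max_gap = np.abs(psi_norm - np.tanh(x)).max()
```

Normalizing before the overlay removes the arbitrary output gain of the MLP, so the comparison with \(\tanh\) reflects only the shape of the learned nonlinearity.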
Latent Embeddings \(\mathbf{a}_i\)
Learned latent vectors for all neurons. Colors indicate true neuron types.
R² Connectivity Over Training Iterations
\(R^2\) between learned and true connectivity \(W_{ij}\) plotted as a function of training iterations for each dataset size.
Code
print()
print("-" * 80)
print("Generating R² over iterations comparison plot")
print("-" * 80)
output_r2 = plot_r2_over_iterations(
    config_list=config_list,
    output_path='./log/signal/tmp_results/r2_over_iterations.png',
    device=device,
)