Supplementary Figures 8 and 9: Sparse connectivity (5% to 100%)

Categories: Neural Activity · Simulation · GNN Training

Authors: Cédric Allier, Stephan Saalfeld

This notebook reproduces the panels of the paper's Supplementary Figures 8 and 9, which assess GNN performance on connectivity matrices with varying sparsity levels. For each sparsity level it displays the connectivity-matrix comparison, the \(\phi^*\) plots, the \(\psi^*\) plots, and the learned embeddings.

Simulation parameters are constant across all experiments; only the connectivity sparsity varies.

The simulation follows Equation 2 from the paper:

\[\frac{dx_i}{dt} = -\frac{x_i}{\tau_i} + s_i \cdot \tanh(x_i) + g_i \cdot \sum_j W_{ij} \cdot \tanh(x_j)\]
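For intuition, here is a minimal sketch of integrating Equation 2 with forward Euler. Every constant below (n_neurons, n_frames, dt, and the choices for tau, s, g, and the Gaussian W) is an illustrative placeholder, not a value used by data_generate.

Code
import numpy as np

rng = np.random.default_rng(0)
n_neurons, n_frames, dt = 100, 1000, 0.1   # placeholders, not the paper's values

tau = np.ones(n_neurons)   # time constants tau_i
s = np.ones(n_neurons)     # self-interaction gains s_i
g = np.ones(n_neurons)     # coupling gains g_i
W = rng.normal(0.0, 1.0 / np.sqrt(n_neurons), (n_neurons, n_neurons))

x = rng.normal(0.0, 1.0, n_neurons)
trace = np.empty((n_frames, n_neurons))
for t in range(n_frames):
    # Equation 2: dx_i/dt = -x_i/tau_i + s_i tanh(x_i) + g_i sum_j W_ij tanh(x_j)
    dx = -x / tau + s * np.tanh(x) + g * (W @ np.tanh(x))
    x = x + dt * dx
    trace[t] = x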

Variable: connectivity sparsity.

| Config              | Sparsity |
|---------------------|----------|
| signal_fig_supp_8   | 5%       |
| signal_fig_supp_8_3 | 10%      |
| signal_fig_supp_8_2 | 20%      |
| signal_fig_supp_8_1 | 50%      |
| signal_fig_2        | 100%     |

Configuration

Code
import glob
import os

print()
print("=" * 80)
print("Supplementary Figure 8: Effect of Connectivity Sparsity")
print("=" * 80)

# All configs to process (config_name, sparsity)
config_list = [
    ('signal_fig_supp_8', '5%'),
    ('signal_fig_supp_8_3', '10%'),
    ('signal_fig_supp_8_2', '20%'),
    ('signal_fig_supp_8_1', '50%'),
    ('signal_fig_2', '100%'),
]

device = None  # set once from the first config
best_model = ''
config_root = "./config"

Steps 1-4: Generate, Train, Plot, and Compute Training Progression for all configs

Loop over all sparsity levels: generate data, train the GNN, generate plots, and compute the R² training progression. Each step is skipped if its outputs already exist.

Code
for config_file_, sparsity in config_list:
    print()
    print("=" * 80)
    print(f"Processing: {config_file_} ({sparsity} sparsity)")
    print("=" * 80)

    config_file, pre_folder = add_pre_folder(config_file_)

    # Load config
    config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
    config.config_file = config_file
    config.dataset = config_file

    if device is None:
        device = set_device(config.training.device)

    log_dir = f'./log/{config_file}'
    graphs_dir = f'./graphs_data/{config_file}'

    # STEP 1: GENERATE
    print()
    print("-" * 80)
    print("STEP 1: GENERATE - Simulating neural activity")
    print("-" * 80)

    data_file = f'{graphs_dir}/x_list_0.npy'
    if os.path.exists(data_file):
        print(f"data already exists at {graphs_dir}/")
        print("skipping simulation, regenerating figures...")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
            regenerate_plots_only=True,
        )
    else:
        print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_frames} frames")
        print(f"output: {graphs_dir}/")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
        )

    # STEP 2: TRAIN
    print()
    print("-" * 80)
    print("STEP 2: TRAIN - Training GNN")
    print("-" * 80)

    model_files = glob.glob(f'{log_dir}/models/*.pt')
    if model_files:
        print(f"trained model already exists at {log_dir}/models/")
        print("skipping training (delete models folder to retrain)")
    else:
        print(f"training for {config.training.n_epochs} epochs")
        print(f"sparsity: {sparsity}")
        data_train(
            config=config,
            erase=False,
            best_model=best_model,
            style='color',
            device=device
        )

    # STEP 3: PLOT
    print()
    print("-" * 80)
    print("STEP 3: PLOT - Generating figures")
    print("-" * 80)

    folder_name = f'{log_dir}/tmp_results/'
    os.makedirs(folder_name, exist_ok=True)

    data_plot(
        config=config,
        config_file=config_file,
        epoch_list=['best'],
        style='color',
        extended='plots',
        device=device,
        apply_weight_correction=True,
        plot_eigen_analysis=False
    )

    # STEP 4: TRAINING PROGRESSION (R² over iterations)
    print()
    print("-" * 80)
    print("STEP 4: TRAINING PROGRESSION - Computing R² over iterations")
    print("-" * 80)

    r2_file = f'{log_dir}/results/all/r2_over_iterations.json'
    if os.path.exists(r2_file):
        print(f"R² data already exists at {r2_file}")
        print("skipping (delete results/all/ folder to recompute)")
    else:
        data_plot(
            config=config,
            config_file=config_file,
            epoch_list=['all'],
            style='color',
            extended='plots',
            device=device,
            apply_weight_correction=True,
            plot_eigen_analysis=False,
        )

Activity Time Series

Sample of 100 time series for each sparsity level.

Supp. Fig 8b: Sample of 100 time series (5% sparsity)

Sample of 100 time series (10% sparsity)

Sample of 100 time series (20% sparsity)

Sample of 100 time series (50% sparsity)

Sample of 100 time series (100% connectivity)
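To inspect the raw traces outside the pipeline, here is a hedged sketch that loads one generated dataset and plots 100 traces. Both the path (which depends on the pre-folder layout created above) and the array layout of x_list_0.npy (time on the first axis, neurons on the second, activity in the last feature column) are assumptions about the on-disk format.

Code
import numpy as np
import matplotlib.pyplot as plt

# assumed shape: (n_frames, n_neurons) or (n_frames, n_neurons, n_features)
x_list = np.load('./graphs_data/signal_fig_supp_8/x_list_0.npy')
activity = x_list[..., -1] if x_list.ndim == 3 else x_list

fig, ax = plt.subplots(figsize=(8, 4))
for i in range(100):
    ax.plot(activity[:, i], lw=0.5, alpha=0.7)
ax.set_xlabel('frame')
ax.set_ylabel('activity $x_i$')
plt.show()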

True Connectivity Matrix \(W_{ij}\)

True connectivity matrix for each sparsity level.

Supp. Fig 8c: True connectivity \(W_{ij}\) (5% sparsity)

True connectivity \(W_{ij}\) (10% sparsity)

True connectivity \(W_{ij}\) (20% sparsity)

True connectivity \(W_{ij}\) (50% sparsity)

True connectivity \(W_{ij}\) (100% connectivity)
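One way to realize a given filling factor is to zero out a random subset of a dense Gaussian matrix; the sketch below illustrates the quantity shown in these panels, though data_generate's actual construction may differ.

Code
import numpy as np

def make_sparse_connectivity(n_neurons, filling_factor, rng):
    """Random Gaussian connectivity with a given fraction of nonzero entries."""
    W = rng.normal(0.0, 1.0 / np.sqrt(n_neurons), (n_neurons, n_neurons))
    mask = rng.random((n_neurons, n_neurons)) < filling_factor
    return W * mask

rng = np.random.default_rng(0)
W_sparse = make_sparse_connectivity(1000, 0.05, rng)  # 5% of entries nonzero
print(f"filling factor: {np.count_nonzero(W_sparse) / W_sparse.size:.3f}")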

Connectivity Matrix Comparison

Learned vs true connectivity matrix \(W_{ij}\) after training. The scatter plot shows \(R^2\) and slope metrics.

Supp. Fig 8e: Connectivity comparison (5% sparsity)

Connectivity comparison (10% sparsity)

Connectivity comparison (20% sparsity)

Connectivity comparison (50% sparsity)

Connectivity comparison (100% connectivity)
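The two scatter metrics can be computed with a least-squares fit of learned against true matrix entries; a hedged sketch follows (the notebook's data_plot may compute them differently).

Code
import numpy as np

def connectivity_metrics(W_true, W_learned):
    """R^2 and slope of a least-squares fit of learned vs. true entries."""
    x, y = W_true.ravel(), W_learned.ravel()
    slope, intercept = np.polyfit(x, y, 1)
    residuals = y - (slope * x + intercept)
    r2 = 1.0 - residuals.var() / y.var()
    return r2, slope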

Update Function \(\phi^*(\mathbf{a}_i, x)\) (MLP0)

Learned update functions after training. Each curve represents one neuron. Colors indicate true neuron types. True functions overlaid in gray.

Supp. Fig 8g: Update functions \(\phi^*(a_i, x)\) (5% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (10% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (20% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (50% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (100% connectivity). True functions are overlaid in light gray.
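The gray reference curves plausibly correspond to the local part of Equation 2, \(-x/\tau_i + s_i \tanh(x)\). Below is a sketch drawing it for four hypothetical \((\tau_i, s_i)\) pairs; the actual neuron-type parameters are not reproduced here.

Code
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-5, 5, 200)
params = [(1.0, 0.5), (1.0, 1.0), (2.0, 0.5), (2.0, 1.0)]  # made-up (tau_i, s_i)

fig, ax = plt.subplots()
for tau_i, s_i in params:
    ax.plot(x, -x / tau_i + s_i * np.tanh(x), label=f'tau={tau_i}, s={s_i}')
ax.set_xlabel('$x$')
ax.set_ylabel(r'$\phi(\mathbf{a}_i, x)$')
ax.legend()
plt.show()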

Transfer Function \(\psi^*(x)\) (MLP1)

Learned transfer function after training, normalized to max=1. True function overlaid in gray.

Supp. Fig 8h: Transfer function \(\psi^*(x)\) (5% sparsity). True function overlaid in light gray.

Transfer function \(\psi^*(x)\) (10% sparsity). True function overlaid in light gray.

Transfer function \(\psi^*(x)\) (20% sparsity). True function overlaid in light gray.

Transfer function \(\psi^*(x)\) (50% sparsity). True function overlaid in light gray.

Transfer function \(\psi^*(x)\) (100% connectivity). True function overlaid in light gray.
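In Equation 2 the transfer function is \(\tanh\), so normalizing to a maximum of 1 puts the learned curve on the same scale. A sketch with a synthetic stand-in for the MLP1 output:

Code
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-5, 5, 200)
psi_learned = 0.8 * np.tanh(1.1 * x)                # stand-in for the MLP1 output
psi_norm = psi_learned / np.abs(psi_learned).max()  # normalize to max = 1

fig, ax = plt.subplots()
ax.plot(x, np.tanh(x), color='lightgray', lw=3, label=r'true $\psi = \tanh$')
ax.plot(x, psi_norm, label=r'learned $\psi^*$ (normalized)')
ax.legend()
plt.show()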

Latent Embeddings \(\mathbf{a}_i\)

Learned latent vectors for all neurons. Colors indicate true neuron types.

Supp. Fig 8f: Latent embeddings \(a_i\) (5% sparsity).

Latent embeddings \(a_i\) (10% sparsity).

Latent embeddings \(a_i\) (20% sparsity).

Latent embeddings \(a_i\) (50% sparsity).

Latent embeddings \(a_i\) (100% connectivity).
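A sketch of this kind of panel using synthetic stand-ins for the 2-D embeddings \(\mathbf{a}_i\) and the type labels (the real values come from the trained model):

Code
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
n_neurons, n_types = 1000, 4
types = rng.integers(0, n_types, n_neurons)    # stand-in type labels
centers = rng.normal(0.0, 1.0, (n_types, 2))   # one cluster center per type
a = centers[types] + rng.normal(0.0, 0.1, (n_neurons, 2))

fig, ax = plt.subplots()
ax.scatter(a[:, 0], a[:, 1], c=types, s=5, cmap='tab10')
ax.set_xlabel('$a_{i,0}$')
ax.set_ylabel('$a_{i,1}$')
plt.show()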

Connectivity R² Over Training Iterations

1000 neurons with 4 neuron-dependent update functions. The plot displays the \(R^2\) between true and learned connectivity matrices \(W_{ij}\) as a function of training iteration for different connectivity filling factors (colors). All comparisons are made at equal numbers of gradient-descent iterations.

Code
print()
print("-" * 80)
print("Generating R² over iterations comparison plot")
print("-" * 80)
output_r2 = plot_r2_over_iterations(
    config_list=config_list,
    output_path='./log/signal/tmp_results/r2_over_iterations_sparsity.png',
    device=device,
)
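For reference, here is a hedged sketch of what such a comparison plot aggregates: one R² curve per config, read from each run's r2_over_iterations.json. Both the JSON layout (parallel iterations and r2 lists) and the per-config path are assumptions about the outputs written above.

Code
import json
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for config_file_, sparsity in config_list:
    # assumed path; depends on the pre-folder layout used by add_pre_folder
    with open(f'./log/signal/{config_file_}/results/all/r2_over_iterations.json') as f:
        data = json.load(f)
    ax.plot(data['iterations'], data['r2'], label=sparsity)
ax.set_xlabel('training iteration')
ax.set_ylabel(r'$R^2$ (true vs. learned $W_{ij}$)')
ax.legend(title='filling factor')
plt.show()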
