API reference

compile_graph

compile_graph(
    edgelist: DataFrame,
    backend: str = "feedforward",
    quiet: bool = False,
)

Compile an edgelist into a sparse PyTorch model and compilation artifact.

The edgelist defines the architecture graph. Each row describes one directed connection from source to target in the direction of computation. In other words, edges should point from input feature nodes toward hidden nodes and output nodes.

Input features are inferred as graph nodes with no incoming edges. Output nodes are inferred as graph nodes with no outgoing edges. The returned artifact stores the inferred input feature names in artifact.feature_names. Tensors passed to the compiled model must have columns in that exact order.
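
The inference rule above can be sketched with plain pandas set logic. This is a hypothetical illustration of the documented behavior, not edge2torch's actual implementation:

```python
import pandas as pd

# Hypothetical sketch of the input/output inference rule,
# not edge2torch's actual implementation.
edgelist = pd.DataFrame(
    {
        "source": ["feature_a", "feature_b", "hidden_1"],
        "target": ["hidden_1", "hidden_1", "prediction"],
    }
)

sources = set(edgelist["source"])
targets = set(edgelist["target"])

# Input features: nodes that never appear as a target (no incoming edges).
input_features = sorted(sources - targets)
# Output nodes: nodes that never appear as a source (no outgoing edges).
output_nodes = sorted(targets - sources)

print(input_features)  # ['feature_a', 'feature_b']
print(output_nodes)    # ['prediction']
```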

Parameters:

Name Type Description Default
edgelist DataFrame

Edge table with required columns "source" and "target". Each row defines a directed connection from one named node to another, following the direction of computation. The table must include edges from input feature nodes into the rest of the architecture graph.

required
backend str

Backend to compile to. One of "feedforward", "recurrent", or "graphnn".

"feedforward"
quiet bool

If False, emit informational notes during validation. If True, suppress informational notes.

False

Returns:

Type Description
tuple[Module, CompileArtifact]

Tuple (model, artifact). model is a PyTorch nn.Module compiled from the edgelist. artifact stores compilation metadata, including artifact.feature_names.

Raises:

Type Description
Edge2TorchError

If input validation, graph validation, or backend compilation fails.

Examples:

Compile a small feedforward architecture from an edgelist.

>>> import pandas as pd
>>> from edge2torch import compile_graph
>>>
>>> edgelist = pd.DataFrame(
...     {
...         "source": ["feature_a", "feature_b", "hidden_1"],
...         "target": ["hidden_1", "hidden_1", "prediction"],
...     }
... )
>>>
>>> model, artifact = compile_graph(
...     edgelist=edgelist,
...     backend="feedforward",
...     quiet=True,
... )
>>>
>>> artifact.feature_names
['feature_a', 'feature_b']

align_features_to_input_nodes

align_features_to_input_nodes(
    data, artifact: CompileArtifact
) -> torch.Tensor

Align data features to the input-node order expected by a compiled model.

compile_graph() builds a sparse neural network from an edgelist. Input nodes are inferred from the graph structure and stored in artifact.feature_names. These names define the required column order for tensors passed to the compiled PyTorch model.

For named data containers, this function validates exact feature-name compatibility and reorders features by name:

  • pandas.DataFrame inputs are aligned using column names.
  • AnnData inputs are aligned using var_names if anndata is installed.

Named data containers must contain exactly the compiled model's input-node features, though the features may appear in any order. Missing or extra features raise an error.

torch.Tensor inputs do not contain feature names, so they are only validated by shape and are assumed to already follow artifact.feature_names order.
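
The name-based alignment rule can be sketched as follows. This is a hypothetical illustration using pandas and NumPy (edge2torch returns a float32 torch.Tensor rather than an array), not the library's actual implementation:

```python
import numpy as np
import pandas as pd

# Hypothetical sketch of the alignment rule, not edge2torch's
# actual implementation.
feature_names = ["feature_a", "feature_b"]  # stand-in for artifact.feature_names

data = pd.DataFrame(
    {
        "feature_b": [2.0, 4.0],
        "feature_a": [1.0, 3.0],
    }
)

# Named containers must match the input-node features exactly.
missing = set(feature_names) - set(data.columns)
extra = set(data.columns) - set(feature_names)
if missing or extra:
    raise ValueError(f"missing={sorted(missing)}, extra={sorted(extra)}")

# Reorder columns by name, then convert to float32.
aligned = data[feature_names].to_numpy(dtype=np.float32)
print(aligned)  # [[1. 2.], [3. 4.]]
```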

Parameters:

Name Type Description Default
data DataFrame | Tensor | AnnData

Input data to align. AnnData is supported when anndata is installed.

required
artifact CompileArtifact

Compilation artifact returned by compile_graph(). Its feature_names field defines the required input-node order.

required

Returns:

Type Description
Tensor

Float32 input tensor whose columns are ordered according to artifact.feature_names.

Raises:

Type Description
Edge2TorchError

If the input data type is unsupported, required features are missing, extra features are present in named data containers, non-numeric DataFrame columns are present, or tensor input has an incompatible shape.

Examples:

Align a DataFrame whose columns are named but not ordered like the compiled model input nodes.

>>> import pandas as pd
>>> import torch
>>> from edge2torch import align_features_to_input_nodes, compile_graph
>>>
>>> edgelist = pd.DataFrame(
...     {
...         "source": ["feature_a", "feature_b", "hidden"],
...         "target": ["hidden", "hidden", "prediction"],
...     }
... )
>>> model, artifact = compile_graph(edgelist, quiet=True)
>>>
>>> data = pd.DataFrame(
...     {
...         "feature_b": [2.0, 4.0],
...         "feature_a": [1.0, 3.0],
...     }
... )
>>>
>>> artifact.feature_names
['feature_a', 'feature_b']
>>>
>>> x = align_features_to_input_nodes(
...     data=data,
...     artifact=artifact,
... )
>>> x
tensor([[1., 2.],
        [3., 4.]])

Tensor inputs do not contain feature names, so they are only checked by shape and are assumed to already follow artifact.feature_names.

>>> x_tensor = torch.tensor(
...     [
...         [1.0, 2.0],
...         [3.0, 4.0],
...     ]
... )
>>> x_from_tensor = align_features_to_input_nodes(
...     data=x_tensor,
...     artifact=artifact,
... )
>>> torch.equal(x_from_tensor, x_tensor)
True

customize_model

customize_model(
    model: Module,
    activation: Module | None = None,
    dropout: float | int | None = None,
    head: Module | None = None,
) -> nn.Module

Wrap a compiled sparse neural network with optional PyTorch modules.

This function is a convenience layer for common post-compilation additions. It applies the requested components sequentially to the output of the compiled model. It does not modify the sparse graph structure or insert modules inside graph-derived layers, and it is not a substitute for ordinary PyTorch training and customization.
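
Conceptually, the sequential wrapping amounts to composing the compiled model with the requested modules. The sketch below is a hypothetical illustration (`wrap_sequentially` is not an edge2torch function, and a plain `nn.Linear` stands in for a compiled sparse model):

```python
import torch
from torch import nn

# Hypothetical sketch of sequential wrapping, not edge2torch's
# actual implementation: each requested component is applied to
# the output of the previous one.
def wrap_sequentially(model, activation=None, dropout=None, head=None):
    modules = [model]
    if activation is not None:
        modules.append(activation)
    if dropout is not None:
        modules.append(nn.Dropout(p=dropout))  # requires 0 <= dropout < 1
    if head is not None:
        modules.append(head)
    return nn.Sequential(*modules)

base = nn.Linear(2, 1)  # stand-in for a compiled sparse model
wrapped = wrap_sequentially(
    base,
    activation=nn.ReLU(),
    dropout=0.2,
    head=nn.Linear(1, 1),
)
print(wrapped)
```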

Parameters:

Name Type Description Default
model Module

PyTorch model returned by compile_graph().

required
activation Module | None

Optional PyTorch activation module applied after the compiled model. This should be an instantiated module such as nn.ReLU().

None
dropout float | int | None

Optional dropout probability applied after the activation. Must satisfy 0 <= dropout < 1.

None
head Module | None

Optional PyTorch module applied after dropout. This should be an instantiated module such as nn.Linear(...).

None

Returns:

Type Description
Module

Wrapped PyTorch model with the requested post-compilation modules.

Raises:

Type Description
Edge2TorchError

If any input is invalid.

Examples:

Add an activation function after the compiled sparse neural network.

>>> import pandas as pd
>>> from torch import nn
>>> from edge2torch import compile_graph, customize_model
>>>
>>> edgelist = pd.DataFrame(
...     {
...         "source": ["feature_a", "feature_b", "hidden"],
...         "target": ["hidden", "hidden", "prediction"],
...     }
... )
>>> model, artifact = compile_graph(edgelist, quiet=True)
>>>
>>> customized_model = customize_model(
...     model=model,
...     activation=nn.ReLU(),
... )

Add an activation, dropout, and task-specific prediction head.

>>> customized_model = customize_model(
...     model=model,
...     activation=nn.ReLU(),
...     dropout=0.2,
...     head=nn.Linear(1, 1),
... )

interpret_model

interpret_model(
    model: Any,
    artifact: Any,
    data: Any,
    target: str = "features",
    method: str = "IntegratedGradients",
    constructor_kwargs: dict[str, Any] | None = None,
    attribute_kwargs: dict[str, Any] | None = None,
    quiet: bool = False,
) -> pd.DataFrame | dict[str, pd.DataFrame]

Interpret a model compiled by edge2torch using a Captum attribution method.

Parameters:

Name Type Description Default
model Any

PyTorch model returned by compile_graph(), optionally customized and trained by the user.

required
artifact Any

Compilation artifact returned by compile_graph().

required
data DataFrame | AnnData | Tensor

Input data used for attribution.

required
target str

Interpretation target. Use "features" to attribute predictions to input features. Use "nodes" to attribute predictions to named internal nodes of a feedforward compiled model.

"features"
method str

Captum attribution method name. Method names follow Captum class names exactly and are case-sensitive, for example "IntegratedGradients", "Saliency", "DeepLift", "LayerConductance", or "LayerIntegratedGradients".

The selected method must be compatible with target and the compiled backend. If an unsupported method is provided, edge2torch raises an error listing the supported method names.

"IntegratedGradients"
constructor_kwargs dict[str, Any] | None

Optional keyword arguments passed directly to the constructor of the selected Captum attribution class. These arguments are passed through unchanged and are not interpreted, validated, or modified by edge2torch. Refer to the Captum documentation for the selected method to determine which constructor arguments are supported.

None
attribute_kwargs dict[str, Any] | None

Optional keyword arguments passed directly to the selected Captum method's attribute() call. These arguments are passed through unchanged and are not interpreted, validated, or modified by edge2torch. Refer to the Captum documentation for the selected method to determine which attribution arguments are supported.

None
quiet bool

If False, emit informational notes. If True, suppress informational notes.

False

Returns:

Type Description
DataFrame | dict[str, DataFrame]

If target="features", returns one DataFrame with rows as examples and columns as input feature names.

If target="nodes", returns a dictionary mapping layer names to DataFrames. Each DataFrame has rows as examples and columns as named nodes for that layer.

Notes

Feature interpretation is currently supported for all implemented backends.

Node interpretation is currently supported only for the "feedforward" backend. Node interpretation methods use Captum layer attribution classes.

interpret_model() temporarily switches the model to evaluation mode while computing attributions and restores the previous training/evaluation mode afterward.
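
The save-and-restore pattern described above can be sketched as follows. This is a hypothetical illustration (`with_eval_mode` is not an edge2torch function), not the library's actual implementation:

```python
import torch
from torch import nn

# Hypothetical sketch of the eval-mode handling, not edge2torch's
# actual implementation.
def with_eval_mode(model, fn):
    """Run fn with the model in eval mode, then restore the prior mode."""
    was_training = model.training
    model.eval()
    try:
        return fn()
    finally:
        model.train(was_training)

model = nn.Dropout(p=0.5)
model.train()
result = with_eval_mode(model, lambda: model.training)
print(result)          # False: eval mode while fn runs
print(model.training)  # True: the previous training mode is restored
```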

constructor_kwargs and attribute_kwargs are passed through to Captum. Refer to the Captum documentation for method-specific arguments such as baselines, targets, additional forward arguments, or perturbation settings.

Raises:

Type Description
Edge2TorchError

If interpretation input validation fails, the requested target/method/backend combination is not supported, or Captum returns unsupported output.

Examples:

Compute feature-level attributions with integrated gradients.

>>> feature_attributions = interpret_model(
...     model=trained_model,
...     artifact=artifact,
...     data=data,
...     target="features",
...     method="IntegratedGradients",
...     quiet=True,
... )
>>> feature_attributions.head()

Compute node-level attributions for a feedforward model.

>>> node_attributions = interpret_model(
...     model=trained_model,
...     artifact=artifact,
...     data=data,
...     target="nodes",
...     method="LayerConductance",
...     quiet=True,
... )
>>> node_attributions.keys()