topologicpy.ANN module
- class topologicpy.ANN.ANN
Bases: object
ANN is a TopologicPy-style helper class for regular tabular ML using PyTorch.
The class mirrors the high-level API of topologicpy.PyG.PyG but targets non-graph datasets (rows in a CSV file).
Methods
ByCSVPath(path[, task, labelHeader, ...]) Create an ANN instance by reading a dataset from a CSV file path.
LoadModel(path[, strict, ...]) Load model weights from disk (backward compatible).
PlotConfusionMatrix([split, normalize, ...]) Plot a confusion matrix for classification using TopologicPy's Plotly helper.
PlotHistory([title]) Plot training history curves (losses) using Plotly.
PlotParity([split, title, show_identity, ...]) Plot a parity plot (true vs predicted) for regression.
Predict([path, return_proba, return_logits, ...]) Predict outputs for the loaded dataset or a new CSV dataset.
SaveModel(path[, include_config]) Save model weights (and optionally config) to disk.
SetHyperparameters(**kwargs) Update training/model hyperparameters.
Summary() Return a concise string summary of the ANN configuration and dataset shape.
Test() Evaluate the model on the test split.
Train() Train the ANN model.
Validate() Evaluate the model on the validation split.
- static ByCSVPath(path: str, task: str = 'classification', labelHeader: str = 'label', featHeader: str = 'feat', idHeader: str = 'id', featuresKeys: Optional[List[str]] = None, useMasksIfPresent: bool = True, split: Tuple[float, float, float] = (0.8, 0.1, 0.1), device: str = 'auto', randomState: int = 42, shuffle: bool = True)
Create an ANN instance by reading a dataset from a CSV file path.
- Parameters
- path : str
Path to a CSV file.
- task : str, optional
“classification” or “regression”. Default is “classification”.
- labelHeader : str, optional
Label column name. Default is “label”.
- featHeader : str, optional
Feature column name (list-like string). Default is “feat”.
- idHeader : str, optional
Optional ID column name. Default is “id”.
- featuresKeys : list, optional
If provided, features will be built from these explicit columns instead of parsing featHeader.
- useMasksIfPresent : bool, optional
If True and mask columns exist, use them. Default is True.
- split : tuple(float, float, float), optional
Train/val/test ratio used when masks are not provided. Default is (0.8, 0.1, 0.1).
- device : str, optional
“auto”, “cpu”, or “cuda”. Default is “auto”.
- randomState : int, optional
RNG seed for splitting and training determinism. Default is 42.
- shuffle : bool, optional
Whether to shuffle prior to ratio split. Default is True.
- Returns
- ANN
A configured ANN instance with data loaded and splits created.
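The default CSV layout and the ratio split described above can be illustrated with the standard library alone. This is a minimal sketch, not the library's implementation: the helper names (build_csv, parse_rows, ratio_split) are hypothetical, the exact list-like syntax that ByCSVPath accepts in the feat column is assumed to be a Python-style list literal, and the rounding used for the split sizes is an assumption.

```python
import ast
import csv
import io
import random


def build_csv(n_rows=10):
    """Write a small CSV with the default headers: id, label, feat."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["id", "label", "feat"])
    for i in range(n_rows):
        # The feature vector is stored as one list-like string per row (assumed format).
        writer.writerow([i, i % 2, str([i * 0.1, i * 0.2])])
    return buffer.getvalue()


def parse_rows(csv_text, feat_header="feat", label_header="label"):
    """Parse each row into (features, label)."""
    rows = []
    for record in csv.DictReader(io.StringIO(csv_text)):
        features = [float(v) for v in ast.literal_eval(record[feat_header])]
        rows.append((features, int(record[label_header])))
    return rows


def ratio_split(n_rows, split=(0.8, 0.1, 0.1), random_state=42, shuffle=True):
    """Return train/val/test index lists for a seeded ratio split."""
    indices = list(range(n_rows))
    if shuffle:
        random.Random(random_state).shuffle(indices)
    n_train = int(round(split[0] * n_rows))
    n_val = int(round(split[1] * n_rows))
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the test split
    return train, val, test


rows = parse_rows(build_csv(10))
train_idx, val_idx, test_idx = ratio_split(len(rows))
print(len(train_idx), len(val_idx), len(test_idx))
```

With a file in this layout on disk, the documented call would be ANN.ByCSVPath(path, task="classification"), leaving the header arguments at their defaults.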
- LoadModel(path: str, strict: bool = True, rebuild_from_checkpoint: bool = True) -> None
Load model weights from disk (backward compatible).
- Parameters
- path : str
Path to a .pt checkpoint file.
- strict : bool, optional
Passed to load_state_dict. Default is True.
- rebuild_from_checkpoint : bool, optional
If True and the checkpoint includes config fields, rebuild the model before loading. Default is True.
- Returns
- None
- PlotConfusionMatrix(split: str = 'test', normalize: bool = False, minValue: float = None, maxValue: float = None, title: Optional[str] = None, xTitle: str = 'Actual Categories', yTitle: str = 'Predicted Categories', width: int = 950, height: int = 500, showScale: bool = True, colorScale: str = 'viridis', colorSamples: int = 10, backgroundColor: str = 'rgba(0,0,0,0)', marginLeft: int = 0, marginRight: int = 0, marginTop: int = 40, marginBottom: int = 0, baseFontSize: int = 16, tickFontSize: int = 14, titleFontSize: int = 22, axisTitleFontSize: int = 16, annotationFontSize: int = 18, grayScale: bool = False, mantissa: int = 6)
Plot a confusion matrix for classification using TopologicPy’s Plotly helper.
- Parameters
- split : str, optional
Which split(s) to evaluate. Options are: {“train”, “val”, “validate”, “validation”, “test”, “all”}. Default is “test”.
- normalize : bool, optional
If True, row-normalize the confusion matrix. Default is False.
- title : str, optional
The desired title to display. Default is “Confusion Matrix”.
- xTitle : str, optional
The desired X-axis title to display. Default is “Actual Categories”.
- yTitle : str, optional
The desired Y-axis title to display. Default is “Predicted Categories”.
- minValue : float, optional
The desired minimum value to use for the color scale. If set to None, the minimum value found in the input matrix will be used.
- maxValue : float, optional
The desired maximum value to use for the color scale. If set to None, the maximum value found in the input matrix will be used.
- width : int, optional
The desired width of the figure. Default is 950.
- height : int, optional
The desired height of the figure. Default is 500.
- showScale : bool, optional
If set to True, a color scale is shown on the right side of the figure. Default is True.
- colorScale : str, optional
The desired type of Plotly color scale to use (e.g. “Viridis”, “Plasma”). Default is “viridis”.
- colorSamples : int, optional
The number of discrete color samples to use for displaying the data. Default is 10.
- backgroundColor : list or str, optional
The desired background color. Default is “rgba(0,0,0,0)” (transparent).
- marginLeft, marginRight, marginTop, marginBottom : int, optional
Plot margins in pixels. Defaults are 0, 0, 40, and 0 respectively.
- baseFontSize : int, optional
The base font size. Default is 16.
- tickFontSize : int, optional
The tick font size. Default is 14.
- titleFontSize : int, optional
The title font size. Default is 22.
- axisTitleFontSize : int, optional
The axis title font size. Default is 16.
- annotationFontSize : int, optional
The annotation font size. Default is 18.
- grayScale : bool, optional
If set to True, the figure is rendered in grayscale. Default is False.
- mantissa : int, optional
The desired length of the mantissa. Default is 6.
- Returns
- plotly.graph_objects.Figure
Confusion matrix figure.
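To make the normalize option concrete, the sketch below builds a confusion matrix from label lists and row-normalizes it in plain Python. It is an illustration only: the function names are hypothetical, and the choice that rows index the actual class (so each row sums to 1 after normalization) is an assumption, since this page does not state which axis the library normalizes over.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """Count (actual, predicted) pairs; rows are actual classes (assumed orientation)."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for actual, predicted in zip(y_true, y_pred):
        matrix[actual][predicted] += 1
    return matrix


def row_normalize(matrix):
    """Divide each row by its sum; empty rows stay all zeros."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized


y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0, 2]
counts = confusion_matrix(y_true, y_pred, n_classes=3)
rates = row_normalize(counts)
for row in rates:
    print([round(value, 3) for value in row])
```

Row normalization makes classes of different sizes comparable, because each cell then reads as the fraction of that row's samples rather than a raw count.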
- PlotHistory(title: str = 'Learning Curves')
Plot training history curves (losses) using Plotly.
- Parameters
- title : str, optional
Plot title. Default is “Learning Curves”.
- Returns
- plotly.graph_objects.Figure
A Plotly figure.
Notes
Requires Plotly.
- PlotParity(split: str = 'test', title: Optional[str] = None, show_identity: bool = True, show_best_fit: bool = True, point_size: int = 6)
Plot a parity plot (true vs predicted) for regression.
- Parameters
- split : {“train”, “val”, “validate”, “validation”, “test”, “all”}, optional
Which split(s) to evaluate. Default is “test”.
- title : str, optional
Custom title. If None, uses an automatic title.
- show_identity : bool, optional
If True, plot y=x. Default is True.
- show_best_fit : bool, optional
If True, plot least-squares fit line. Default is True.
- point_size : int, optional
Marker size. Default is 6.
- Returns
- plotly.graph_objects.Figure
Parity scatter plot.
Notes
Requires Plotly.
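The two reference lines in a parity plot can be computed without any plotting library. The sketch below fits an ordinary least-squares line to (true, predicted) pairs; the function name is hypothetical and the code illustrates the general technique, not necessarily how PlotParity computes its fit.

```python
def least_squares_line(y_true, y_pred):
    """Fit y_pred ~ slope * y_true + intercept by ordinary least squares."""
    n = len(y_true)
    mean_x = sum(y_true) / n
    mean_y = sum(y_pred) / n
    sxx = sum((x - mean_x) ** 2 for x in y_true)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(y_true, y_pred))
    slope = sxy / sxx  # undefined if all true values are identical
    intercept = mean_y - slope * mean_x
    return slope, intercept


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.5, 3.5, 4.5]  # every prediction is 0.5 too high
slope, intercept = least_squares_line(y_true, y_pred)
print(slope, intercept)
```

A perfect model lies on the identity line y=x (slope 1, intercept 0). Here the slope is 1 but the intercept is 0.5, which is how a constant bias shows up when the fit line is compared with the identity line.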
- Predict(path: Optional[str] = None, return_proba: bool = False, return_logits: bool = False, attach_to_df: bool = True) -> Dict[str, Any]
Predict outputs for the loaded dataset or a new CSV dataset.
- Parameters
- path : str, optional
If provided, load a new CSV file and run inference on it. If None, predicts on the currently loaded dataset.
- return_proba : bool, optional
If True and task is classification, return class probabilities. Default is False.
- return_logits : bool, optional
If True and task is classification, return raw logits. Default is False.
- attach_to_df : bool, optional
If True, attach predictions back to self.df (or the newly loaded df) as columns: “pred” and optionally “proba_*”. Default is True.
- Returns
- dict
Prediction package with keys:
- “pred”: numpy array of predictions (class indices or regression values)
- “proba”: numpy array (N,C) if requested for classification
- “logits”: numpy array (N,C) if requested for classification
- “df”: pandas DataFrame if attach_to_df=True
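The relationship between the “logits”, “proba”, and “pred” entries for classification can be sketched as follows. This assumes the usual convention that probabilities are the softmax of the logits and that the predicted class is the argmax; the page does not confirm this, and predict_package is a hypothetical helper that uses plain lists rather than numpy arrays.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def predict_package(batch_logits, return_proba=False, return_logits=False):
    """Assemble a dict shaped like the documented prediction package (without 'df')."""
    probas = [softmax(row) for row in batch_logits]
    package = {"pred": [max(range(len(p)), key=p.__getitem__) for p in probas]}
    if return_proba:
        package["proba"] = probas
    if return_logits:
        package["logits"] = [list(row) for row in batch_logits]
    return package


batch = [[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]]
result = predict_package(batch, return_proba=True)
print(result["pred"])
```

Optional keys appear only when requested, which mirrors the "if requested" wording in the Returns description.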
- SaveModel(path: str, include_config: bool = True) -> None
Save model weights (and optionally config) to disk.
- Parameters
- path : str
Output file path. If it does not end with .pt, the .pt extension is appended.
- include_config : bool, optional
If True, saves a checkpoint dict that includes config fields to support robust reload. Default is True.
- Returns
- None
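The path handling described for SaveModel amounts to a small normalization step. The helper below is a hypothetical sketch of that rule; whether the library compares the extension case-sensitively is an assumption.

```python
def ensure_pt_suffix(path):
    """Append '.pt' unless the path already ends with it (case-sensitive, assumed)."""
    return path if path.endswith(".pt") else path + ".pt"


print(ensure_pt_suffix("models/tabular_ann"))
print(ensure_pt_suffix("models/tabular_ann.pt"))
```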
- SetHyperparameters(**kwargs) -> None
Update training/model hyperparameters.
- Parameters
- **kwargs : dict
Any attribute of the internal config. Common keys include: epochs, batch_size, lr, weight_decay, optimizer, hidden_dims, activation, dropout, batch_norm, early_stopping, early_stopping_patience, gradient_clip_norm, split, shuffle, random_state, device.
- Returns
- None
Notes
If a model-shaping parameter changes (e.g. hidden_dims, activation, batch_norm), the model is rebuilt automatically.
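The rebuild-on-change behavior in the note above can be sketched as a pattern. Everything here is hypothetical: the Config fields and their default values, the TinyTrainer class, and the decision to raise on unknown keys are illustrative assumptions, not the library's internals. Only the three model-shaping keys named in the note are treated as such.

```python
from dataclasses import dataclass, field
from typing import List

# Keys the note above names as model-shaping.
MODEL_SHAPING_KEYS = {"hidden_dims", "activation", "batch_norm"}


@dataclass
class Config:
    # Default values are placeholders for illustration only.
    epochs: int = 100
    lr: float = 1e-3
    hidden_dims: List[int] = field(default_factory=lambda: [64, 32])
    activation: str = "relu"
    batch_norm: bool = False


class TinyTrainer:
    def __init__(self):
        self.config = Config()
        self.rebuild_count = 0

    def _build_model(self):
        # A real implementation would construct the network here.
        self.rebuild_count += 1

    def set_hyperparameters(self, **kwargs):
        needs_rebuild = False
        for key, value in kwargs.items():
            if not hasattr(self.config, key):
                raise AttributeError(f"Unknown hyperparameter: {key}")
            if key in MODEL_SHAPING_KEYS and getattr(self.config, key) != value:
                needs_rebuild = True
            setattr(self.config, key, value)
        if needs_rebuild:
            self._build_model()


trainer = TinyTrainer()
trainer.set_hyperparameters(epochs=20, lr=0.01)       # training-only keys: no rebuild
trainer.set_hyperparameters(hidden_dims=[128, 64])    # model-shaping key: rebuild
print(trainer.rebuild_count)
```

Batching several keys into one call triggers at most one rebuild in this sketch, which is a reason to set related options together.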
- Summary() -> str
Return a concise string summary of the ANN configuration and dataset shape.
- Returns
- str
Human-readable summary.
- Test() -> Dict[str, Any]
Evaluate the model on the test split.
- Returns
- dict
Metrics dictionary. For classification includes accuracy/f1_macro/etc. For regression includes mae/rmse/r2. Always includes y_true and y_pred.
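The metrics named above can be recomputed from the returned y_true and y_pred values. The sketch below uses the standard textbook definitions in plain Python; the function names are hypothetical, and the library's own metric implementations (for example how it handles classes absent from the split) may differ.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return mae, rmse, and r2 for paired true/predicted values."""
    n = len(y_true)
    errors = [pred - true for true, pred in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)
    mean_true = sum(y_true) / n
    ss_tot = sum((true - mean_true) ** 2 for true in y_true)
    return {
        "mae": sum(abs(e) for e in errors) / n,
        "rmse": math.sqrt(ss_res / n),
        "r2": 1.0 - ss_res / ss_tot,
    }


def classification_metrics(y_true, y_pred):
    """Return accuracy and macro-averaged F1 over all classes seen."""
    pairs = list(zip(y_true, y_pred))
    f1_scores = []
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return {
        "accuracy": sum(1 for t, p in pairs if t == p) / len(pairs),
        "f1_macro": sum(f1_scores) / len(f1_scores),
    }


reg = regression_metrics([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0])
cls = classification_metrics([0, 0, 1, 1], [0, 1, 1, 1])
print(reg)
print(cls)
```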
- Train() -> Dict[str, List[float]]
Train the ANN model.
- Returns
- dict
Training history dictionary with per-epoch curves. Typical keys: “train_loss”, “val_loss” and task-specific metric keys.
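A history dictionary with per-epoch lists is easy to inspect after training. The sketch below finds the epoch with the lowest validation loss and replays a simple patience rule over the curve. The history values are made up for illustration, and the patience rule shown is a common formulation that is only assumed to resemble the early_stopping_patience option.

```python
def best_epoch(history, key="val_loss"):
    """Index (0-based) of the epoch with the lowest value for the given curve."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__)


def early_stop_epoch(val_losses, patience):
    """Epoch at which training would stop, or None if patience is never exhausted."""
    best = float("inf")
    stale = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None


# Hypothetical history in the documented shape.
history = {
    "train_loss": [0.9, 0.6, 0.4, 0.3, 0.25],
    "val_loss": [1.0, 0.7, 0.6, 0.65, 0.7],
}
print(best_epoch(history))
print(early_stop_epoch(history["val_loss"], patience=2))
```

In this made-up run the training loss keeps falling while the validation loss rises after epoch 2, the usual sign of overfitting that early stopping is meant to catch.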
- Validate() -> Dict[str, Any]
Evaluate the model on the validation split.
- Returns
- dict
Metrics dictionary. For classification includes accuracy/f1_macro/etc. For regression includes mae/rmse/r2. Always includes y_true and y_pred.
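Taken together, the methods documented on this page suggest a load, configure, train, evaluate, save sequence. The sketch below strings them together using only the method names and keyword arguments listed above. The ANN class is passed in as a parameter so the function can be read and exercised without TopologicPy installed; the hyperparameter values and their types (for example hidden_dims as a list of integers), the file paths, and the function name itself are assumptions.

```python
def run_tabular_workflow(ann_class, csv_path, model_path):
    """Run a typical classification workflow with an ANN-like class."""
    # Load the dataset and create the train/val/test splits.
    ann = ann_class.ByCSVPath(csv_path, task="classification", labelHeader="label")

    # Keys are taken from the documented list; the values are illustrative.
    ann.SetHyperparameters(epochs=50, lr=0.001, hidden_dims=[64, 32])

    history = ann.Train()            # per-epoch curves, e.g. "train_loss", "val_loss"
    val_metrics = ann.Validate()     # metrics on the validation split
    test_metrics = ann.Test()        # metrics on the test split

    # Keep the config in the checkpoint so LoadModel can rebuild the model.
    ann.SaveModel(model_path, include_config=True)

    return {
        "summary": ann.Summary(),
        "history": history,
        "validation": val_metrics,
        "test": test_metrics,
    }


# Usage, assuming TopologicPy with the ANN module is installed:
# from topologicpy.ANN import ANN
# results = run_tabular_workflow(ANN, "data.csv", "model.pt")
```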