# graph_recognition_w_attn/config_.py
# Last modified: 2025-09-01 14:46:34 -04:00 (30 lines, 824 B, Python)

import sys
from enum import Enum
class NoiseType(Enum):
    """Selectable noise modes, chosen via ``TrainConfig.noise_type``.

    How ``TrainConfig.noise_level`` is interpreted depends on the mode;
    see the comment on each member.
    """

    # No noise is applied.
    NONE = "none"
    # Normal (Gaussian) noise; ``noise_level`` is its standard deviation.
    NORMAL = "normal"
    # Uniform noise; ``noise_level`` is the half-width of its range.
    UNIFORM = "uniform"
class ModelConfig:
    """Model hyperparameters, held as plain class attributes.

    This is not a dataclass: ``ModelConfig()`` accepts no arguments, and
    ``num_agents`` is declared by annotation only, so it has no value (and
    reading it raises ``AttributeError``) until a caller assigns it.
    """

    # Annotation only -- no default; must be assigned before it is read.
    num_agents: int

    # Defaulted dimensions (meaning inferred from the names; confirm at use sites).
    embedding_dim: int = 64
    input_dim: int = 1
    output_dim: int = 1
class TrainConfig:
    """Configuration for the training process.

    Values are plain class attributes rather than dataclass fields, so the
    class takes no constructor arguments. ``data_directory`` is declared by
    annotation only and must be assigned before it is read.
    """

    # --- optimisation ---
    learning_rate: float = 0.001
    epochs: int = 100
    batch_size: int = 4 * 4096  # = 16384

    # --- reporting ---
    verbose: bool = False  # when True, show the loss during training
    log: bool = True  # when True, log the attention every `log_epoch_interval`
    log_epoch_interval: int = 10  # logging interval (in epochs, per the name)

    # --- noise ---
    noise_type: NoiseType = NoiseType.NONE  # which kind of noise to apply
    noise_level: float = 0.0  # std-dev for NORMAL, half-width for UNIFORM

    # --- data and evaluation ---
    data_directory: str  # no default; assign before use
    # Threshold for binary classification.
    # NOTE(review): negative value -- presumably compared against logits rather
    # than probabilities; confirm at the call site.
    f1_threshold: float = -0.5