# File: graph_recognition_w_attn/config_.py (30 lines, 824 B, Python)
# Source view: Raw | Normal View | History
# Blame: 2025-08-04 12:44:35 -04:00
import sys
from enum import Enum
class NoiseType(Enum):
    """Closed set of noise kinds selectable in the training config.

    Each member's value is its lowercase string name.
    """

    NONE = "none"        # no noise
    NORMAL = "normal"    # normal-distributed noise
    UNIFORM = "uniform"  # uniform-distributed noise
class ModelConfig:
    """Configuration for the model.

    Fields are class attributes, so ``ModelConfig()`` exposes the defaults
    below. ``num_agents`` has no default: reading it before it has been
    assigned (via the constructor or attribute assignment) raises
    ``AttributeError``.
    """

    num_agents: int  # no default; set before use
    embedding_dim: int = 64
    input_dim: int = 1
    output_dim: int = 1

    def __init__(self, **overrides: object) -> None:
        """Create a config, optionally overriding fields by keyword.

        ``ModelConfig()`` behaves as before (class defaults, ``num_agents``
        unset); ``ModelConfig(num_agents=8)`` sets fields at construction.

        Raises:
            TypeError: if a keyword is not an annotated field of this class
                (or a base class), so a misspelled option fails loudly
                instead of creating a stray attribute.
        """
        # Union over the MRO so subclasses that add fields are accepted too.
        known = {
            name
            for klass in type(self).__mro__
            for name in getattr(klass, "__annotations__", {})
        }
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(
                    f"{type(self).__name__}() got an unexpected field {name!r}"
                )
            setattr(self, name, value)
class TrainConfig:
    """Configuration for the training process.

    Fields are class attributes, so ``TrainConfig()`` exposes the defaults
    below. ``data_directory`` has no default: reading it before it has been
    assigned (via the constructor or attribute assignment) raises
    ``AttributeError``.
    """

    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 4096 * 4
    verbose: bool = False  # Show the loss during training
    log: bool = True  # Log the attention every ``log_epoch_interval`` epochs
    log_epoch_interval: int = 10  # How often (in epochs) to log
    noise_type: NoiseType = NoiseType.NONE  # Kind of noise to use
    noise_level: float = 0.00  # Stddev for Normal, half-width for Uniform
    data_directory: str  # no default; set before use
    # NOTE(review): a negative default suggests the threshold is compared
    # against raw scores/logits rather than probabilities -- confirm at the
    # call site.
    f1_threshold: float = -0.5  # Threshold for binary classification

    def __init__(self, **overrides: object) -> None:
        """Create a config, optionally overriding fields by keyword.

        ``TrainConfig()`` behaves as before (class defaults,
        ``data_directory`` unset); ``TrainConfig(data_directory="data")``
        sets fields at construction.

        Raises:
            TypeError: if a keyword is not an annotated field of this class
                (or a base class), so a misspelled option fails loudly
                instead of creating a stray attribute.
        """
        # Union over the MRO so subclasses that add fields are accepted too.
        known = {
            name
            for klass in type(self).__mro__
            for name in getattr(klass, "__annotations__", {})
        }
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(
                    f"{type(self).__name__}() got an unexpected field {name!r}"
                )
            setattr(self, name, value)