import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
np.random.seed(42)

# Read data from CSV file (expects columns 'x' and 'y')
data = pd.read_csv('linear_regression_data.csv')
x = data['x'].values
y = data['y'].values
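
# If 'linear_regression_data.csv' is not available, a comparable synthetic
# dataset can be generated instead (the slope, intercept, and noise scale
# below are illustrative, not taken from the original data):
# x = np.linspace(0, 10, 100)
# y = 2.5 * x + 1.0 + np.random.randn(100) * 2.0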

# Standardize (z-score) the data so gradient descent is well-conditioned
x_mean, x_std = np.mean(x), np.std(x)
y_mean, y_std = np.mean(y), np.std(y)

x_norm = (x - x_mean) / x_std
y_norm = (y - y_mean) / y_std
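
# The standardization above assumes neither column is constant; a zero
# standard deviation would make the divisions above undefined.
assert x_std > 0 and y_std > 0, "x and y must not be constant"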

class SimpleLinearNN:
    """
    A simple neural network with one linear layer for linear regression.
    This is essentially a PINN (Physics-Informed Neural Network) approach
    where the physics is linear relationship: y = wx + b
    """
    def __init__(self, learning_rate=0.01):
        # Initialize weight and bias to small random values
        self.w = np.random.randn() * 0.01  # slope (weight)
        self.b = np.random.randn() * 0.01  # intercept (bias)
        self.lr = learning_rate
        
    def forward(self, x):
        """Forward pass: y = wx + b (linear activation)"""
        return self.w * x + self.b
    
    def compute_loss(self, y_pred, y_true):
        """Mean Squared Error loss function"""
        return np.mean((y_pred - y_true) ** 2)
    
    def backward(self, x, y_pred, y_true):
        """Backward pass: compute gradients using chain rule"""
        n = len(x)
        error = y_pred - y_true
        
        # Compute gradients
        # dL/dw = (2/n) * sum(error * x)
        # dL/db = (2/n) * sum(error)
        dw = (2/n) * np.sum(error * x)
        db = (2/n) * np.sum(error)
        
        return dw, db
    
    def update_weights(self, dw, db):
        """Update weights using gradient descent"""
        self.w -= self.lr * dw
        self.b -= self.lr * db
    
    def train(self, x, y, epochs=1000):
        """Training loop"""
        losses = []
        
        print("Training neural network for linear regression...")
        for epoch in range(epochs):
            # Forward pass
            y_pred = self.forward(x)
            
            # Compute loss
            loss = self.compute_loss(y_pred, y)
            losses.append(loss)
            
            # Backward pass
            dw, db = self.backward(x, y_pred, y)
            
            # Update weights
            self.update_weights(dw, db)
            
            # Print progress
            if (epoch + 1) % 200 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss:.6f}')
        
        return losses
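
# Optional sanity check (a minimal sketch, not part of the original pipeline):
# the analytic gradients from backward() should agree with central finite
# differences of the loss. check_gradients is a helper introduced here.
def check_gradients(model, x, y, eps=1e-6):
    """Compare analytic gradients against central finite differences."""
    y_pred = model.forward(x)
    dw, db = model.backward(x, y_pred, y)

    # Central difference with respect to w
    w0 = model.w
    model.w = w0 + eps
    loss_plus = model.compute_loss(model.forward(x), y)
    model.w = w0 - eps
    loss_minus = model.compute_loss(model.forward(x), y)
    model.w = w0
    dw_num = (loss_plus - loss_minus) / (2 * eps)

    # Central difference with respect to b
    b0 = model.b
    model.b = b0 + eps
    loss_plus = model.compute_loss(model.forward(x), y)
    model.b = b0 - eps
    loss_minus = model.compute_loss(model.forward(x), y)
    model.b = b0
    db_num = (loss_plus - loss_minus) / (2 * eps)

    print(f"Gradient check: dw {dw:.6f} vs {dw_num:.6f}, db {db:.6f} vs {db_num:.6f}")

# Example usage: check_gradients(SimpleLinearNN(), x_norm, y_norm)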

# Create and train the neural network
nn = SimpleLinearNN(learning_rate=0.1)
losses = nn.train(x_norm, y_norm, epochs=1000)

# Get predictions on normalized data
y_pred_norm = nn.forward(x_norm)

# Denormalize predictions to original scale
y_pred = y_pred_norm * y_std + y_mean

# Convert learned parameters back to original scale
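# Derivation: the model fits y_norm = w*x_norm + b; substituting
# x_norm = (x - x_mean)/x_std and y_norm = (y - y_mean)/y_std and solving
# for y gives y = (w*y_std/x_std)*x + (y_mean + b*y_std - w*y_std*x_mean/x_std).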
weight_original = nn.w * y_std / x_std
bias_original = nn.b * y_std + y_mean - nn.w * y_std * x_mean / x_std

print(f"\nLearned parameters:")
print(f"Slope (m): {weight_original:.6f}")
print(f"Intercept (c): {bias_original:.6f}")

# Plot results
plt.figure(figsize=(12, 5))

# Plot 1: Data and regression line
plt.subplot(1, 2, 1)
plt.scatter(x, y, alpha=0.6, label='Data points')
plt.plot(x, y_pred, color='red', linewidth=2, label=f'NN Fit: y = {weight_original:.2f}x + {bias_original:.2f}')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Linear Regression using Neural Network')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Training loss
plt.subplot(1, 2, 2)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Calculate R-squared for goodness of fit
ss_res = np.sum((y - y_pred) ** 2)
ss_tot = np.sum((y - np.mean(y)) ** 2)
r_squared = 1 - (ss_res / ss_tot)

print(f"R-squared: {r_squared:.6f}")
print(f"Final training loss: {losses[-1]:.6f}")

# Show the neural network architecture
print(f"\nNeural Network Architecture:")
print(f"Input layer: x (1 neuron)")
print(f"Output layer: y = {nn.w:.6f} * x + {nn.b:.6f} (1 neuron)")
print(f"Learned parameters: weight = {nn.w:.6f}, bias = {nn.b:.6f}")
print(f"Activation function: Linear")
print(f"Optimizer: Gradient Descent")
print(f"Loss function: Mean Squared Error")