import numpy as np
import matplotlib.pyplot as plt

# Parameters
#
# Module-level configuration for the simulation. Several names below are not
# referenced by the visible part of the script; they are kept so that any
# other code relying on them keeps working.
N = 100  # Number of Nodes
xm = 100  # Upper bound for the random x coordinate of a node
ym = 100  # Upper bound for the random y coordinate of a node
zm = 100  # Upper bound for the random depth of a node
Depth_threshold = 10
Eo = 10  # Initial Energy in joules
a = 1  # NOTE(review): purpose not evident from the visible code
Et = 0
f = 30  # kHz

# Thorp's empirical absorption formula (f in kHz):
#   0.11*f^2/(1 + f^2) + 44*f^2/(4100 + f^2) + 2.75e-4*f^2 + 0.003
# The second term's denominator is 4100 + f^2; the earlier version of this
# line used 4100 + f, which overstated the attenuation.
thorp_atten = (
    0.11 * (f**2 / (1 + f**2))
    + 44 * (f**2 / (4100 + f**2))
    + ((2.75 * 10**-4) * f**2)
    + 0.003
)
bw = 30000  # Hz
P_t = 2  # watt
P_r = 0.1  # watt
E_elec = 50 * (10**-9)  # j/bit (transmitter electronics energy)
E_amp = 100 * (10**-12)  # j/bit/m2 (Amplifier energy)
EDA = 5 * 10**(-9)  # Data aggregation
l = 1000  # data packets
datarate = 250000  # bps
lh = 48  # hello packet size

# Per-packet energy costs: power multiplied by packet size over
# (thorp_atten * bw). ETX and ETXh are what the simulation subtracts from a
# node's energy when it sends.
ETX = (P_t * (l / (thorp_atten * bw)))
ERX = (P_r * (l / (thorp_atten * bw)))
ETXh = (P_t * (lh / (thorp_atten * bw)))

rmax = 1000  # Number of Rounds
TX_range = 25  # Transmission Range in meters
packets_send = 0  # Running count of packets sent
pkts_drpd = 0  # Running count of packets dropped
Et_DBR = 0
offset = round(xm / 3)

# Round indices 0..rmax inclusive, matching the simulation loop's range.
r_values = list(range(rmax + 1))

# Node class
class SensorNode:
    """A sensor node: a position, a depth, an energy level and an id.

    Attributes:
        x, y: Coordinates of the node.
        depth: Depth of the node.
        energy: Current energy of the node.
        node_id: Identifier given to the node at creation.
        neighbors: List of neighbor entries; starts empty and is filled in
            by the caller.
    """

    def __init__(self, x, y, depth, energy, node_id):
        self.x, self.y, self.depth = x, y, depth
        self.energy, self.node_id = energy, node_id
        # A fresh list per instance so nodes never share neighbor storage.
        self.neighbors = list()

# Sensor node deployment
# Build N nodes with ids 0..N-1. Each node draws, in this order, a random x
# in [0, xm), a random y in [0, ym) and a random depth in [0, zm), and starts
# with energy Eo. Call arguments are evaluated left to right, so the random
# draws happen in the x, y, depth order for every node.
sensorNodeInfo = [
    SensorNode(
        np.random.rand() * xm,
        np.random.rand() * ym,
        np.random.rand() * zm,
        Eo,
        node_index,
    )
    for node_index in range(N)
]

# Divide the area into zones and regions
# Four zones of equal width along x; each zone is paired with four bands of
# equal height along y, giving 16 (zone_x_range, region_y_range) entries
# ordered zone by zone.
zone_width = xm / 4
region_height = ym / 4
zone_coordinates = [
    (zone_index * zone_width, (zone_index + 1) * zone_width)
    for zone_index in range(4)
]
region_coordinates = [
    (zone_span, (band * region_height, (band + 1) * region_height))
    for zone_span in zone_coordinates
    for band in range(4)
]

# Sink positions
# Four sinks stored as (x, y) tuples, all starting at x = 1 with y values
# 25 apart.
sinkNodeInfo = [(1, sink_y) for sink_y in (12.5, 37.5, 62.5, 87.5)]

# Initialize statistics
# One independent, initially empty list per tracked series; the key order
# below is the dictionary's iteration order.
STATISTICS = {
    series_name: []
    for series_name in (
        'Round',
        'Network_Energy',
        'packets_send',
        'PACKETS_DROPPED',
        'DEAD',
        'Alive_node_m',
    )
}

# Simulation
#
# Runs rounds 0..rmax inclusive. Each round moves the sinks, lets every node
# with remaining energy send packets to sinks in range, and then records
# exactly one entry per series in STATISTICS so that every series lines up
# with STATISTICS['Round'].


def _planar_distance(x1, y1, x2, y2):
    """Return the Euclidean distance between two points in the x-y plane."""
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)


# Neighbor finding.
# Node positions never change during the simulation, so the neighbor lists
# are computed once up front instead of being rebuilt in every round.
# A node's neighbors are the indices of the nodes with a strictly larger y
# coordinate that lie within TX_range of it in the x-y plane (depth is not
# part of the distance).
for i, node in enumerate(sensorNodeInfo):
    node.neighbors = [
        j
        for j, other in enumerate(sensorNodeInfo)
        if i != j
        and node.y < other.y
        and _planar_distance(node.x, node.y, other.x, other.y) <= TX_range
    ]

for r in range(rmax + 1):
    # Sink mobility: every sink advances 5 units along x and wraps back to
    # x = 1 once its x coordinate has reached 99 or more.
    # NOTE(review): starting from x = 1 the sequence is 1, 6, ..., 96, 101,
    # so a sink sits at x = 101 for one round before wrapping -- confirm
    # that this overshoot is intended.
    sinkNodeInfo[:] = [
        ((sink_x + 5) if sink_x < 99 else 1, sink_y)
        for sink_x, sink_y in sinkNodeInfo
    ]

    # Packet sending
    for node in sensorNodeInfo:
        sink_distances = [
            _planar_distance(node.x, node.y, sink_x, sink_y)
            for sink_x, sink_y in sinkNodeInfo
        ]

        # One data packet (cost ETX) per sink within range, for as long as
        # the node still has energy left.
        for distance in sink_distances:
            if distance <= TX_range and node.energy > 0:
                packets_send += 1
                node.energy -= ETX

        # If the node still has energy and some sink is strictly closer than
        # TX_range, one more packet is counted at the hello-packet cost.
        if node.energy > 0 and any(d < TX_range for d in sink_distances):
            packets_send += 1
            node.energy -= ETXh

        # A node with no energy left counts as one dropped packet per round.
        if node.energy <= 0:
            pkts_drpd += 1

    # Update statistics: exactly one entry per series per round.
    # Network energy is the sum over nodes that still have positive energy,
    # measured after this round's transmissions.
    total_energy = sum(node.energy for node in sensorNodeInfo if node.energy > 0)
    dead = sum(1 for node in sensorNodeInfo if node.energy <= 0)
    alive_node_m = N - dead

    STATISTICS['Round'].append(r)
    STATISTICS['Network_Energy'].append(total_energy)
    STATISTICS['packets_send'].append(packets_send)
    STATISTICS['PACKETS_DROPPED'].append(pkts_drpd)
    STATISTICS['DEAD'].append(dead)
    STATISTICS['Alive_node_m'].append(alive_node_m)

    # Display statistics periodically
    if r % 100 == 0:
        print(f"Round {r}:")
        print(f"Packets Sent: {packets_send}")
        print(f"Packets Dropped: {pkts_drpd}")
        print(f"Network Energy: {total_energy}")
        print(f"Dead Nodes: {dead}")
        print(f"Alive Nodes: {alive_node_m}")
        print()

# Final summary, using the values left over from the last round.
print("Dead nodes:", dead)
print("Rounds:", r)
print("Packets dropped:", pkts_drpd)
print("Packets sent:", packets_send)


