import json
import cv2
import trt_pose.coco
import os
import pickle
import trt_pose.models
import torch
from torch2trt import TRTModule
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects
from preprocessdata import preprocessdata
from gesture_classifier import gesture_classifier
import torchvision.transforms as transforms
import PIL.Image
from jetcam.csi_camera import CSICamera
#from jetcam.usb_camera import USBCamera
import math
import serial
import time
import statistics

# ---------------------------------------------------------------------------
# Model setup: load the hand-pose topology, make sure a TensorRT-optimised
# engine exists on disk (building it on first run), then load that engine.
# ---------------------------------------------------------------------------

# Keypoint / skeleton description of the hand model.
# NOTE(review): this path is relative to the working directory while the other
# resources in this file use absolute paths — confirm the script is always
# launched from the project directory.
with open('preprocess/hand_pose.json', 'r') as f:
    hand_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(hand_pose)
num_parts = len(hand_pose['keypoints'])
num_links = len(hand_pose['skeleton'])

# The PyTorch model is only needed as the source for the TensorRT conversion
# below; inference in the main loop goes through ``model_trt``.
model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()

# Network input resolution (also used to configure the camera).
WIDTH = 224
HEIGHT = 224

MODEL_WEIGHTS = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244.pth'
OPTIMIZED_MODEL = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244_trt.pth'

if not os.path.exists(OPTIMIZED_MODEL):
    print("creating optimized model")
    model.load_state_dict(torch.load(MODEL_WEIGHTS))
    import torch2trt
    # torch2trt traces the model with an example input. It has the same
    # layout preprocess() produces: one RGB frame of HEIGHT x WIDTH pixels.
    data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()
    model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)
    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

# Always run inference from the serialized engine so both code paths
# (freshly built / already cached) end up with the same kind of module.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))

# Post-processing helpers that turn network output into keypoints / drawings.
parse_objects = ParseObjects(topology,cmap_threshold=0.15, link_threshold=0.15)
draw_objects = DrawObjects(topology)

# Per-channel normalisation constants applied in preprocess(); the values
# match the commonly used ImageNet channel mean / standard deviation.
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
device = torch.device('cuda')

def preprocess(image):
    """Convert a BGR camera frame into a normalised CUDA batch tensor.

    The frame is converted to RGB, turned into a float tensor via
    torchvision's ``to_tensor``, moved to the CUDA device, normalised in
    place with the module-level ``mean`` / ``std`` and returned with a
    leading batch axis added.

    Side effect: rebinds the module-level ``device`` to the CUDA device on
    every call.
    """
    global device
    device = torch.device('cuda')

    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb_frame))
    tensor = tensor.to(device)

    # Broadcast the per-channel statistics over the two spatial axes.
    channel_mean = mean[:, None, None]
    channel_std = std[:, None, None]
    tensor.sub_(channel_mean).div_(channel_std)

    return tensor[None, ...]

# Helper objects for joint extraction and gesture classification.
# NOTE: each assignment rebinds the imported class name to an instance of
# that class, so the classes themselves are no longer reachable by name.
preprocessdata = preprocessdata(topology, num_parts)

gesture_classifier = gesture_classifier()

# Trained gesture classifier.
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load model files that come from a trusted source.
filename = 'large.sav'
with open(filename, 'rb') as model_file:  # context manager closes the handle
    clf = pickle.load(model_file)

# Gesture class names; the 'classes' entry is what the label helpers receive.
with open('/home/articulight/Desktop/project/preprocess/gesture.json', 'r') as f: # might need to be made in gesture collection file
    gesture = json.load(f)
gesture_type = gesture['classes']

def draw_joints(image, joints):
    """Draw the hand keypoints and skeleton links onto ``image`` in place.

    Nothing is drawn when three or more joints equal ``[0, 0]``
    (presumably the marker for an undetected keypoint). Otherwise every
    joint gets a small red circle, the first joint is additionally marked
    in magenta, and skeleton links are drawn in green until the first
    link that touches a joint whose x coordinate is 0.

    Args:
        image: BGR image that OpenCV draws on.
        joints: sequence of ``[x, y]`` pixel coordinates, one per keypoint.

    Returns:
        None.
    """
    zeroed = sum(1 for joint in joints if joint == [0,0])
    if zeroed >= 3:
        return

    for joint in joints:
        cv2.circle(image, (joint[0], joint[1]), 2, (0,0,255), 1)

    first = joints[0]
    cv2.circle(image, (first[0], first[1]), 2, (255,0,255), 1)

    # Skeleton entries are 1-based keypoint indices, hence the "- 1".
    for link in hand_pose['skeleton']:
        start = joints[link[0] - 1]
        if start[0] == 0:
            break
        end = joints[link[1] - 1]
        if end[0] == 0:
            break
        cv2.line(image, (start[0], start[1]), (end[0], end[1]), (0,255,0), 1)

# Frame source: CSI camera configured with the network input size
# (WIDTH x HEIGHT) and a capture rate of 8 fps. The USB-camera variant is
# kept commented out as an alternative.
#camera = USBCamera(width=224, height=224, capture_width=640, capture_height=480, capture_fps = 8, capture_device=1)
camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=8)
# Disabled plain-OpenCV capture alternative, kept as an inert string literal.
'''camera = cv2.VideoCapture(0)
camera.set(cv2.CAP_PROP_FPS, 8)
camera.set(3, 224)
camera.set(4, 224)'''


# Serial link that servo() writes direction commands to:
# 9600 baud, 8 data bits, no parity, 1 stop bit.
serial_port = serial.Serial(
    port = "/dev/ttyUSB0", # ACM0 -> usb-usb, USB0 -> usb-pins
    baudrate = 9600,
    bytesize = serial.EIGHTBITS,
    parity = serial.PARITY_NONE,
    stopbits = serial.STOPBITS_ONE
)

# NOTE(review): presumably gives the device on the other end time to become
# ready after the port is opened — confirm the required delay.
time.sleep(1)

def servo(directions):
    """Send a servo direction command to the ATmega over the serial port.

    The payload is the text ``x<dx>y<dy>``, built from the first two
    entries of ``directions`` and written as encoded bytes.

    Args:
        directions: indexable holding the x and y direction codes.

    Returns:
        None. Any exception raised while writing is printed, not re-raised,
        so a failed write does not stop the caller.
    """
    message = 'x%sy%s' % (str(directions[0]), str(directions[1]))
    payload_writer = None
    try:
        payload_writer = serial_port.write
        payload_writer(message.encode())
    except Exception as e:
        print(e)

"""
#disabled helper: could map pixel offsets to servo travel distances, trading accuracy for cheaper computation (note: math.abs does not exist; use the built-in abs() if this is re-enabled)
def calculate_center(position):
    #returns offset of supplied position and center of camera feed
    return [math.abs((WIDTH/2) - position[0]), math.abs((HEIGHT/2) - position[1])]
"""

def get_directions(position, dead_zone=25, center=(112, 112)):
    """Return ``[x_code, y_code]`` telling the motors which way to move.

    Codes: 0 -> negative direction, 1 -> no movement, 2 -> positive
    direction. A square dead zone around ``center`` yields 1 on the
    affected axis so the motors do not over-compensate.

    The y axis is inverted relative to x: a position below the centre
    (larger y) maps to 0 and a position above it maps to 2.

    Args:
        position: ``[x, y]`` pixel coordinates of the tracked point.
        dead_zone: half-width of the dead zone in pixels (inclusive).
        center: ``(x, y)`` pixel the point should be steered towards;
            defaults to the centre of a 224 x 224 frame.

    Returns:
        A two-element list ``[x_code, y_code]``.
    """
    def axis_code(offset, positive_code, negative_code):
        # Inside the dead zone (boundary included) the axis stays put.
        if abs(offset) <= dead_zone:
            return 1
        return positive_code if offset > 0 else negative_code

    return [
        axis_code(position[0] - center[0], 2, 0),
        axis_code(position[1] - center[1], 0, 2),
    ]

# --- Runtime state for the gesture state machine and the timing metrics ---
error = 25  # dead-zone half-width in pixels used by the go_here centring check
previous_gesture = ''  # label from the previous iteration (used to detect a newly shown gesture)
current_gesture = ''
follow_me = False  # True while "follow me" mode is active
go_here = False  # True while "go here" mode is active
loop_times = [[] for _ in range(10)] # for calculating computation time metric
time_index = 0  # index of the loop_times bucket currently being filled
gather_metrics = True  # cleared once the last bucket holds 100 samples

# Main loop: grab a frame, estimate the hand joints, classify the gesture and
# drive the servos according to the active mode. Ctrl+C leaves the loop.
while(1):
    try:
        start_time = time.time()
        #gesture detection
        image = camera.read()
        # Frame is rotated by 180 degrees before processing.
        image = cv2.rotate(image, cv2.ROTATE_180)
        #success, image = camera.read()
        data = preprocess(image)
        cmap, paf = model_trt(data)
        # Move the network outputs to the CPU before parsing.
        cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
        counts, objects, peaks = parse_objects(cmap, paf)
        joints = preprocessdata.joints_inference(image, counts, objects, peaks)
        draw_joints(image, joints)
        dist_bn_joints = preprocessdata.find_distance(joints)
        # The classifier receives two rows (the features plus an all-zero
        # row); only the first prediction is used below.
        # NOTE(review): the reason for the zero row is not visible here — confirm.
        gesture = clf.predict([dist_bn_joints, [0]*num_parts*num_parts])
        gesture_joints = gesture[0]
        #print("index joint: " + str(joints[6][:]))
        # Fixed-length history of predictions: push the newest, drop the oldest.
        preprocessdata.prev_queue.append(gesture_joints)
        preprocessdata.prev_queue.pop(0)
        preprocessdata.print_label(image, preprocessdata.prev_queue, gesture_type)
        
        #logic for handling gestures
        # NOTE(review): presumably the label derived from prev_queue — confirm in preprocessdata.
        current_gesture = preprocessdata.get_label(gesture_type)
        position = joints[6][:] # position of index finger in frame

        if(follow_me):
            # follow_me mode: steer while 'three' is shown; a newly shown
            # 'ok/rock on 2' (not present on the previous frame) leaves the mode.
            if(current_gesture == 'three'):
                #send servo data to center the point
                 directions = get_directions(position)
                 servo(directions)
            elif((current_gesture == 'ok/rock on 2') and (previous_gesture != 'ok/rock on 2')):
                follow_me = False
                print('follow me deactivated')
        elif(go_here):
            # go_here mode: while 'thumbs up' is shown, steer until the point
            # lies within `error` pixels of (112, 112) on both axes, then leave the mode.
            if(current_gesture == 'thumbs up'):
                directions = get_directions(position)
                if((abs(position[0] - 112) <= error) and (abs(position[1] - 112) <= error)):
                    go_here = False
                    print('go_here deactivated')
                else:
                    servo(directions)
        else:
            # Idle: 'ok/rock on 2' enters follow_me; a newly shown 'thumbs up' enters go_here.
            if(current_gesture == 'ok/rock on 2'):
                follow_me = True
                print('follow me activated')
            elif(current_gesture == 'thumbs up' and previous_gesture != 'thumbs up'): # need to have position be from thumb instead -> joints[1][:]
                go_here = True
                print('go_here activated')

        cv2.imshow('camera feed', image)
        cv2.waitKey(1)
        previous_gesture = current_gesture

        #metrics for presentation
        # Iteration times are collected in 10 buckets of 100 samples each;
        # collection stops once the last bucket (index 9) is full.
        if(gather_metrics):
            if(len(loop_times[9]) == 100):
                gather_metrics = False
            else:
                loop_times[time_index].append(time.time() - start_time)
                if(len(loop_times[time_index]) == 100):
                    if(time_index < 9):
                        time_index += 1
                        print('time_index updated to %i' % time_index)


    except KeyboardInterrupt:
        # Ctrl+C ends the loop; execution continues after it.
        break

# Report the timing metrics, then release the serial port.
# Buckets are filled in order, so interrupting the loop early leaves the
# trailing ones empty; statistics.mean() raises StatisticsError on empty data,
# so those buckets are skipped. The port is closed in ``finally`` so it is
# released even if the report itself fails.
try:
    mean_list = []
    for i, samples in enumerate(loop_times):
        if not samples:
            continue
        bucket_mean = statistics.mean(samples)
        mean_list.append(bucket_mean)
        print('mean %s: %s' % (str(i), str(bucket_mean)))

    if mean_list:
        print('total mean time taken between each iteration of main loop: ' + str(statistics.mean(mean_list)) + ' seconds')
    else:
        print('no loop timing samples were collected')
finally:
    serial_port.close()
