"""
articulight.py

Authors: NVIDIA, Benjamin Simms (bensimms@knights.ucf.edu)

This script is built off of trt_pose_hand (https://github.com/NVIDIA-AI-IOT/trt_pose_hand)
and controls the behavior of the Articulight system for summer 2022 group 9's senior design project 
in the ECE department at the University of Central Florida.

This script loads a model that detects and returns the position of each keypoint in a hand (finger joints, palm, etc.)
and feeds that model's output into a support vector machine (see svm.py for model details) to classify the gesture.
Based on the gesture, a serial message is sent to an Atmega328 microcontroller on a PCB designed by the rest of the group
to drive the servo motors that move the Articulight lighting system.

Much of this code was developed by NVIDIA for their Jetson Nano educational projects.
Namely, the model loading and the image preprocessing that prepares frames for the model were developed
solely by NVIDIA and were translated from a .ipynb notebook into this script. Since a substantial
amount of code in this script was developed by NVIDIA:
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
"""

import json
import cv2
import trt_pose.coco
import os
import pickle
import trt_pose.models
import torch
from torch2trt import TRTModule
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects
from preprocessdata import preprocessdata
import torchvision.transforms as transforms
import PIL.Image
from jetcam.csi_camera import CSICamera
import serial
import time
import statistics

# --- Pose model setup ---------------------------------------------------------

# Hand keypoint names and skeleton links (COCO-style category description).
with open('preprocess/hand_pose.json', 'r') as f:
    hand_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(hand_pose)
num_parts = len(hand_pose['keypoints'])
num_links = len(hand_pose['skeleton'])

# num_parts confidence-map channels and 2 * num_links part-affinity-field channels
# (presumably an x and a y component per link).
model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()

WIDTH = 224
HEIGHT = 224

# NOTE: the "244_244" in these file names does not match the 224x224 input size;
# the names are kept as-is so the existing files on disk are still found.
MODEL_WEIGHTS = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244.pth'
OPTIMIZED_MODEL = '/home/articulight/Desktop/project/model/hand_pose_resnet18_att_244_244_trt.pth'

# Build the TensorRT-optimized model once and cache it on disk; later runs load the cache.
if not os.path.exists(OPTIMIZED_MODEL):
    print("creating optimized model")
    model.load_state_dict(torch.load(MODEL_WEIGHTS))
    import torch2trt
    # torch2trt traces the model with an example input. It must have the same
    # shape as what preprocess() produces: a batch of one 3-channel HEIGHT x WIDTH
    # image on the GPU. (This previously referenced an undefined `data`, which
    # raised NameError on the first run.)
    example_input = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()
    model_trt = torch2trt.torch2trt(model, [example_input], fp16_mode=True, max_workspace_size=1 << 25)
    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))

# Turns raw confidence maps / part affinity fields into per-object keypoint peaks.
parse_objects = ParseObjects(topology, cmap_threshold=0.15, link_threshold=0.15)
draw_objects = DrawObjects(topology)

# Per-channel normalization constants (the standard ImageNet RGB mean/std).
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
device = torch.device('cuda')

def preprocess(image):
    """Convert a BGR camera frame into a normalized, batched tensor for the pose model.

    Args:
        image: frame as read from the camera (BGR channel order, as implied by
            the COLOR_BGR2RGB conversion below).

    Returns:
        A tensor on ``device`` with a leading batch dimension of 1, i.e. shape
        (1, C, H, W), normalized with the module-level ``mean`` / ``std``.
    """
    # Uses the module-level `device`; the old `global device` rebind on every
    # frame assigned the same value again and only mutated global state.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    # to_tensor gives a C x H x W float tensor scaled to [0, 1].
    image = transforms.functional.to_tensor(image).to(device)
    # Normalize in place; [:, None, None] broadcasts the per-channel stats over H and W.
    image.sub_(mean[:, None, None]).div_(std[:, None, None])
    return image[None, ...]

# --- Gesture classifier setup -------------------------------------------------

# NOTE: this rebinds the imported `preprocessdata` class name to an instance; the
# main loop relies on the instance being reachable under this name.
preprocessdata = preprocessdata(topology, num_parts)

filename = 'large.sav'  # SVM model
# Close the file after loading (the previous bare open() leaked the handle).
# NOTE: pickle can execute arbitrary code while loading; only load trusted model files.
with open(filename, 'rb') as model_file:
    clf = pickle.load(model_file)

# Gesture class names, indexed by the SVM's predicted label.
with open('/home/articulight/Desktop/project/preprocess/gesture.json', 'r') as f:
    gesture = json.load(f)
gesture_type = gesture['classes']

def draw_joints(image, joints):
    """Draw the detected hand keypoints and skeleton links onto ``image`` in place.

    Args:
        image: OpenCV image to draw on (modified in place).
        joints: sequence of [x, y] pixel coordinates, one per keypoint, in the
            order of hand_pose['keypoints']. An entry equal to [0, 0] is treated
            as an undetected keypoint.

    Nothing is drawn when three or more keypoints are undetected.
    """
    missing = sum(1 for point in joints if point == [0, 0])
    if missing >= 3:
        return

    # Every keypoint as a small red circle (BGR colour order).
    for point in joints:
        cv2.circle(image, (point[0], point[1]), 2, (0, 0, 255), 1)

    # Highlight keypoint 0 in magenta.
    cv2.circle(image, (joints[0][0], joints[0][1]), 2, (255, 0, 255), 1)

    # Skeleton links hold 1-based keypoint indices, hence the "- 1".
    for link in hand_pose['skeleton']:
        start = joints[link[0] - 1]
        end = joints[link[1] - 1]
        # Skip only this link when an endpoint is missing. The previous `break`
        # stopped drawing every remaining link after the first gap.
        if start[0] == 0 or end[0] == 0:
            continue
        cv2.line(image, (start[0], start[1]), (end[0], end[1]), (0, 255, 0), 1)

# --- Hardware setup -----------------------------------------------------------

# CSI camera delivering WIDTH x HEIGHT frames at 8 frames per second.
camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=8)

# Serial link to the Atmega328: 9600 baud, 8 data bits, no parity, 1 stop bit.
# Device node: ACM0 -> usb-usb, USB0 -> usb-pins
_serial_settings = {
    'port': "/dev/ttyUSB0",
    'baudrate': 9600,
    'bytesize': serial.EIGHTBITS,
    'parity': serial.PARITY_NONE,
    'stopbits': serial.STOPBITS_ONE,
}
serial_port = serial.Serial(**_serial_settings)

# Give the serial connection a moment to settle before the main loop writes to it.
time.sleep(1)

# Serial command helpers for the Atmega servo controller.
def servo_step(directions):
    """Send one two-axis servo command to the Atmega over ``serial_port``.

    The message has the form ``x<first>y<second>`` (for example ``x2y0``); each
    value is converted with str() and sent unchanged.

    Args:
        directions: two-element sequence of per-axis values.

    Write errors are printed and otherwise ignored, so a failed write does not
    stop the caller.
    """
    first, second = directions[0], directions[1]
    payload = ''.join(('x', str(first), 'y', str(second)))
    try:
        serial_port.write(payload.encode())
    except Exception as err:  # deliberate best-effort: keep the control loop alive
        print(err)

def get_directions(position, center=(112, 112), error=25):
    """Return the [x, y] motor directions that move ``position`` toward ``center``.

    Each axis is encoded as 0 -> negative direction, 1 -> no movement,
    2 -> positive direction. Offsets of at most ``error`` pixels map to 1. This
    creates a square dead zone (side 2 * error) around the center so the motors
    do not over-compensate.

    Args:
        position: (x, y) pixel coordinates.
        center: (x, y) pixel coordinates of the target point. The default
            (112, 112) is the middle of the 224x224 camera frame.
        error: dead-zone half-width in pixels (inclusive).

    Returns:
        A two-element list [x_direction, y_direction]. The y axis is inverted:
        a position below the center (larger y) gives 0 and above gives 2.
    """
    def _axis(offset):
        # Inside the dead zone -> 1; otherwise the sign picks 2 (positive) or 0.
        if abs(offset) <= error:
            return 1
        return 2 if offset > 0 else 0

    center_x, center_y = center
    # The y offset is negated so that "above the center" maps to the positive direction.
    return [_axis(position[0] - center_x), _axis(center_y - position[1])]

# --- Main control loop --------------------------------------------------------

current_gesture = ''
follow_me = False
follow_me_safeguard = False # blocks deactivation until 'Rock on' is released after activation
follow_me_deactivate_safeguard = True # blocks re-activation until 'Rock on' is released after deactivation
gesture_buffer = [] # keeps track of gestures
buffer_size = 8 # 8 corresponds to 1 second because camera samples 8 times per second.
loop_times = [[] for _ in range(10)] # for calculating computation time metric
time_index = 0 # metric variable
gather_metrics = True # whether to gather metrics
first_frame = True # fixes bug in metric logic


def _buffer_holds(name):
    """Return True when gesture_buffer is full and every entry equals ``name``.

    A full buffer of one label means the gesture has been held for buffer_size
    consecutive frames.
    """
    return len(gesture_buffer) == buffer_size and all(ges == name for ges in gesture_buffer)


try:
    while True:
        try:
            start_time = time.time()
            # gesture detection
            image = camera.read()
            image = cv2.rotate(image, cv2.ROTATE_180)
            data = preprocess(image)
            cmap, paf = model_trt(data)
            cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
            counts, objects, peaks = parse_objects(cmap, paf)
            joints = preprocessdata.joints_inference(image, counts, objects, peaks)
            draw_joints(image, joints)
            dist_bn_joints = preprocessdata.find_distance(joints)
            # The second, all-zero row is a padding sample; only the first
            # prediction is used (presumably so predict always receives a 2-D batch).
            gesture = clf.predict([dist_bn_joints, [0]*num_parts*num_parts])
            gesture_joints = gesture[0]
            preprocessdata.prev_queue.append(gesture_joints)
            preprocessdata.prev_queue.pop(0)
            preprocessdata.print_label(image, preprocessdata.prev_queue, gesture_type)

            # logic for handling gestures
            current_gesture = preprocessdata.get_label(gesture_type)

            # The buffer holds the last buffer_size labels. If they are all the
            # same gesture, it has been held for (1/8)*buffer_size seconds.
            gesture_buffer.append(current_gesture)
            if len(gesture_buffer) > buffer_size:
                gesture_buffer.pop(0)

            index_position = joints[6][:] # position of index finger in frame
            thumb_position = joints[2][:] # position of thumb in frame

            # After a deactivation, re-activation is allowed only once 'Rock on' has been released.
            if current_gesture != 'Rock on' and not follow_me_deactivate_safeguard:
                follow_me_deactivate_safeguard = True

            if follow_me:
                # After an activation, deactivation is allowed only once 'Rock on' has been released.
                if current_gesture != 'Rock on' and not follow_me_safeguard:
                    follow_me_safeguard = True
                #send servo data to center the index finger
                #directions = get_directions(index_position)
                #servo_step(directions)
                servo_step(index_position)
                if follow_me_safeguard and _buffer_holds('Rock on'):
                    follow_me = False
                    # Arm the safeguard so that continuing to hold 'Rock on' does not
                    # immediately re-activate follow-me. Previously this flag was never
                    # cleared, so the release check above was dead code.
                    follow_me_deactivate_safeguard = False
                    gesture_buffer.clear() # reset buffer
                    print('follow me deactivated')
            else:
                if follow_me_deactivate_safeguard and _buffer_holds('Rock on'):
                    follow_me = True
                    follow_me_safeguard = False
                    gesture_buffer.clear() # reset buffer
                    print('follow me activated')
                if _buffer_holds('Thumbs up'):
                    message = 'w' + str(thumb_position[0]) + 'h' + str(thumb_position[1])
                    print('go here activated')
                    serial_port.write(message.encode())
                    gesture_buffer.clear() # reset buffer
                if _buffer_holds('Three'):  # previously compared against a literal 8
                    message = 'r'
                    print('reset servos activated')
                    gesture_buffer.clear() # reset buffer
                    serial_port.write(message.encode())

            cv2.imshow('camera feed', image)
            cv2.waitKey(1)

            # metrics
            # The first frame is skipped because its much longer run time would skew the data.
            if not first_frame:
                if gather_metrics:
                    if len(loop_times[9]) == 100:
                        gather_metrics = False
                    else:
                        loop_times[time_index].append(time.time() - start_time)
                        if len(loop_times[time_index]) == 100:
                            if time_index < 9:
                                time_index += 1
            else:
                first_frame = False

        except KeyboardInterrupt:
            break

    # Buckets fill in order, so an early Ctrl+C leaves later buckets empty.
    # statistics.mean([]) raises StatisticsError, so empty buckets are skipped.
    mean_list = []
    for i, samples in enumerate(loop_times):
        if not samples:
            continue
        bucket_mean = statistics.mean(samples)
        mean_list.append(bucket_mean)
        print('mean %s: %s' % (str(i), str(bucket_mean)))
    if mean_list:
        print('total mean time taken between each iteration of main loop: ' + str(statistics.mean(mean_list)) + ' seconds')
    else:
        print('no timing data collected')
finally:
    # Always release the serial port, even if the loop or the summary raised.
    serial_port.close()