| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from PIL import Image |
| import time |
| import os |
|
|
| from concrete.fhe import Configuration |
| from concrete.ml.torch.compile import compile_torch_model |
|
|
| from custom_resnet import resnet18_custom |
|
|
| |
# Human-readable labels for the classifier outputs: position i names output
# unit i (the model head is sized with len(class_names) and predict() indexes
# this list with the argmax of the outputs).
class_names = [
    "Fake",
    "Real",
]
|
|
| |
def load_model(model_path, device):
    """Build the custom ResNet-18 classifier and load trained weights into it.

    Args:
        model_path: Path to a checkpoint holding the model's ``state_dict``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to (e.g. ``"cpu"``).

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    print("load_model")
    net = resnet18_custom(weights=None)

    # Replace the classification head with one output per entry in class_names.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, len(class_names))

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (weights_only=True is available on newer torch versions).
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)

    net = net.to(device)
    net.eval()
    return net
|
|
|
|
def load_secure_model(model):
    """Compile ``model`` with Concrete ML into a quantized, FHE-compatible model.

    Args:
        model: The PyTorch classifier to compile. It is moved to CPU in place
            before compilation.

    Returns:
        The object returned by ``compile_torch_model``. ``predict`` calls its
        ``forward(..., fhe="simulate")`` method.
    """
    print("Compiling secure model...")

    fhe_configuration = Configuration(
        enable_tlu_fusing=True,
        print_tlu_fusing=False,
        use_gpu=False,
    )
    # Calibration set used by the compiler to pick quantization ranges.
    # NOTE(review): this is uniform random noise in [0, 1), not real images;
    # ranges may not reflect real inputs. Confirm this is intentional.
    calibration_inputs = torch.rand(10, 3, 224, 224)

    return compile_torch_model(
        model.to("cpu"),
        n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
        rounding_threshold_bits={"n_bits": 7, "method": "APPROXIMATE"},
        p_error=0.05,
        configuration=fhe_configuration,
        torch_inputset=calibration_inputs,
    )
|
|
| |
# Build both inference back-ends once at import time: the plain PyTorch model
# ("Fast" mode) and its Concrete ML compiled counterpart ("Secure" mode).
model = load_model("models/deepfake_detection_model.pth", "cpu")
secure_model = load_secure_model(model)
|
|
| |
# Preprocessing applied to every uploaded image: resize to 224x224, then
# convert to a float tensor (ToTensor scales pixel values to [0, 1]).
# NOTE(review): there is no Normalize step; confirm this matches the
# preprocessing used when the model was trained.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
|
|
| |
def predict(image, mode, expected_output=None):
    """Classify an image as Fake or Real and report the inference time.

    Args:
        image: Path to the image file (the Gradio ``Image`` input uses
            ``type="filepath"``). ``None`` when nothing was uploaded.
        mode: ``"Fast"`` runs the plain PyTorch model; ``"Secure"`` runs the
            Concrete ML compiled model via ``forward(..., fhe="simulate")``.
        expected_output: Optional label typed by the user. It is only echoed
            back for comparison and does not affect the prediction.

    Returns:
        A 3-tuple of strings: the prediction message, the expected-output
        message, and the elapsed inference time in seconds.

    Raises:
        gr.Error: If no image was provided or ``mode`` is not recognised.
    """
    device = "cpu"

    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image or select an example.")
    if mode not in ("Fast", "Secure"):
        # Previously an unknown mode left `outputs` unbound (UnboundLocalError).
        raise gr.Error(f"Unknown inference mode: {mode!r}")

    # The context manager guarantees the file handle PIL opens is released.
    with Image.open(image) as img:
        tensor = data_transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # perf_counter is monotonic and intended for measuring durations.
        start_time = time.perf_counter()
        if mode == "Fast":
            outputs = model(tensor)
        else:
            # The compiled Concrete ML model consumes and returns NumPy arrays.
            clear_input = tensor.detach().numpy()
            outputs = torch.from_numpy(secure_model.forward(clear_input, fhe="simulate"))
        preds = outputs.argmax(dim=1)
        elapsed_time = time.perf_counter() - start_time

    predicted_class = class_names[preds[0].item()]

    # Treat a blank or whitespace-only textbox the same as "not provided".
    expected = (expected_output or "").strip()
    expected_output_message = f"Expected: {expected}" if expected else "Expected: Not Provided"
    predicted_output_message = f"Predicted: {predicted_class}"

    return predicted_output_message, expected_output_message, f"Time taken: {elapsed_time:.2f} seconds"
|
|
|
|
| |
# Sample rows for the examples gallery. Each row follows the order of the
# interface inputs and of predict(image, mode, expected_output):
# (image path, inference mode, expected label). The previous ordering put the
# expected label before the mode, which did not match those inputs.
example_images = [
    ["./data/fake/fake_1.jpeg", "Fast", "Fake"],
    ["./data/real/real_1.jpg", "Fast", "Real"],
]
|
|
| |
# Gradio UI wiring: input components are passed to predict() positionally, and
# each element of predict()'s returned 3-tuple fills one output Textbox.
iface = gr.Interface(
    title="Deepfake Detection Model",
    description=(
        "Upload an image or select a sample and choose the inference mode "
        "(Fast or Secure). Compare the predicted result with the expected output."
    ),
    fn=predict,
    inputs=[
        # Order must match predict(image, mode, expected_output).
        gr.Image(type="filepath", label="Upload an Image"),
        gr.Radio(choices=["Fast", "Secure"], label="Inference Mode", value="Fast"),
        gr.Textbox(
            label="Expected Output",
            value=None,
            placeholder="Optional: Enter expected output (Fake/Real)",
        ),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Expected Output"),
        gr.Textbox(label="Time Taken"),
    ],
    # Rows: (image path, inference mode, expected label).
    examples=[
        ["./data/fake/fake_1.jpeg", "Fast", "Fake"],
        ["./data/real/real_1.jpg", "Fast", "Real"],
    ],
)
|
|
if __name__ == "__main__":
    # Start the web UI only when run as a script. share=True asks Gradio for a
    # public share link in addition to the local server.
    iface.launch(
        share=True,
    )