# #!./.mnist-keras/bin/python
# import json
# import os

# import docker
# import fire
# import numpy as np
# import tensorflow as tf

# from fedn.utils.helpers import get_helper, save_metadata, save_metrics

# HELPER_MODULE = 'kerashelper'
# NUM_CLASSES = 10


# def _get_data_path():
#     # Figure out FEDn client number from container name
#     client = docker.from_env()
#     container = client.containers.get(os.environ['HOSTNAME'])
#     number = container.name[-1]
#     # Return data path
#     return f"/var/data/clients/{number}/mnist.npz"


# def _compile_model(img_rows=28, img_cols=28):
#     # Set input shape
#     input_shape = (img_rows, img_cols, 1)

#     # Define model
#     model = tf.keras.models.Sequential()
#     model.add(tf.keras.layers.Flatten(input_shape=input_shape))
#     model.add(tf.keras.layers.Dense(64, activation='relu'))
#     model.add(tf.keras.layers.Dropout(0.5))
#     model.add(tf.keras.layers.Dense(32, activation='relu'))
#     model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
#     model.compile(loss=tf.keras.losses.categorical_crossentropy,
#                   optimizer=tf.keras.optimizers.Adam(),
#                   metrics=['accuracy'])
#     return model


# def _load_data(data_path, is_train=True):
#     # Load data
#     if data_path is None:
#         data = np.load(_get_data_path())
#     else:
#         data = np.load(data_path)

#     if is_train:
#         X = data['x_train']
#         y = data['y_train']
#     else:
#         X = data['x_test']
#         y = data['y_test']

#     # Normalize
#     X = X.astype('float32')
#     X = np.expand_dims(X, -1)
#     X = X / 255
#     y = tf.keras.utils.to_categorical(y, NUM_CLASSES)

#     return X, y


# def init_seed(out_path='seed.npz'):
#     weights = _compile_model().get_weights()
#     helper = get_helper(HELPER_MODULE)
#     helper.save(weights, out_path)


# def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1):
#     # Load data
#     x_train, y_train = _load_data(data_path)

#     # Load model
#     model = _compile_model()
#     helper = get_helper(HELPER_MODULE)
#     weights = helper.load(in_model_path)
#     model.set_weights(weights)

#     # Train
#     model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

#     # Save
#     weights = model.get_weights()
#     helper.save(weights, out_model_path)

#     # Metadata needed for aggregation server side
#     metadata = {
#         'num_examples': len(x_train),
#         'batch_size': batch_size,
#         'epochs': epochs,
#     }

#     # Save JSON metadata file
#     save_metadata(metadata, out_model_path)


# def validate(in_model_path, out_json_path, data_path=None):
#     # Load data
#     x_train, y_train = _load_data(data_path)
#     x_test, y_test = _load_data(data_path, is_train=False)

#     # Load model
#     model = _compile_model()
#     helper = get_helper(HELPER_MODULE)
#     weights = helper.load(in_model_path)
#     model.set_weights(weights)

#     # Evaluate
#     model_score = model.evaluate(x_train, y_train)
#     model_score_test = model.evaluate(x_test, y_test)
#     y_pred = model.predict(x_test)
#     y_pred = np.argmax(y_pred, axis=1)

#     # JSON schema
#     report = {
#         "training_loss": model_score[0],
#         "training_accuracy": model_score[1],
#         "test_loss": model_score_test[0],
#         "test_accuracy": model_score_test[1],
#     }

#     # Save JSON
#     save_metrics(report, out_json_path)


# def infer(in_model_path, out_json_path, data_path=None):
#     # Using test data for inference but another dataset could be loaded
#     x_test, _ = _load_data(data_path, is_train=False)

#     # Load model
#     model = _compile_model()
#     helper = get_helper(HELPER_MODULE)
#     weights = helper.load(in_model_path)
#     model.set_weights(weights)

#     # Infer
#     y_pred = model.predict(x_test)
#     y_pred = np.argmax(y_pred, axis=1)

#     # Save JSON
#     with open(out_json_path, "w") as fh:
#         fh.write(json.dumps({'predictions': y_pred.tolist()}))


# if __name__ == '__main__':
#     fire.Fire({
#         'init_seed': init_seed,
#         'train': train,
#         'validate': validate,
#         'infer': infer,
#         '_get_data_path': _get_data_path,  # for testing
#     })


# Updated entrypoint.py: uses CIFAR-10 data and adds a FedAvg aggregation step.
# The previous MNIST version is kept above, commented out, for reference.

#!./.mnist-keras/bin/python
import json
import os
import docker
import fire
import numpy as np
import tensorflow as tf
from fedn.utils.helpers import get_helper, save_metadata, save_metrics

# Name handed to get_helper(); the returned helper loads and saves model weights.
HELPER_MODULE = 'kerashelper'
# Size of the softmax output layer and of the one-hot label vectors.
NUM_CLASSES = 10  # CIFAR-10 has 10 classes

def _get_data_path():
    """Return the path of this client's CIFAR-10 archive.

    The client number is read from the end of the name of the Docker
    container identified by the ``HOSTNAME`` environment variable
    (presumably the container this process runs in -- confirm against the
    deployment).  All trailing digits are used, so client ``10`` maps to
    ``/var/data/clients/10/...`` instead of ``/var/data/clients/0/...``.
    If the name does not end in a digit, its last character is used.

    Returns:
        str: ``/var/data/clients/<number>/cifar10.npz``.

    Raises:
        KeyError: If ``HOSTNAME`` is not set in the environment.
    """
    import re  # local import so the fix does not depend on file-level imports

    # Figure out FEDn client number from container name
    client = docker.from_env()
    container = client.containers.get(os.environ['HOSTNAME'])
    name = container.name

    # Take the whole trailing run of digits, not just the final character.
    trailing_digits = re.search(r'[0-9]+$', name)
    number = trailing_digits.group() if trailing_digits else name[-1]

    # Return data path
    return f"/var/data/clients/{number}/cifar10.npz"

def _compile_model(img_rows=32, img_cols=32, img_channels=3):
    """Build and compile the convolutional classifier.

    Args:
        img_rows: First entry of the model's input shape.
        img_cols: Second entry of the model's input shape.
        img_channels: Third entry of the model's input shape.

    Returns:
        A compiled ``tf.keras`` Sequential model ending in a
        ``NUM_CLASSES``-way softmax, using categorical cross-entropy loss,
        the Adam optimizer and the accuracy metric.
    """
    layers = tf.keras.layers

    # Layer stack, listed in the order data flows through it.
    network = tf.keras.models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu',
                      input_shape=(img_rows, img_cols, img_channels)),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(32, activation='relu'),
        layers.Dense(NUM_CLASSES, activation='softmax'),
    ])

    network.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=['accuracy'],
    )
    return network

def _load_data(data_path, is_train=True):
    """Load one split from an ``.npz`` archive and prepare it for the model.

    Args:
        data_path: Path of the archive to read.  When ``None`` the path
            returned by ``_get_data_path()`` is used.
        is_train: If true, read ``x_train``/``y_train``; otherwise read
            ``x_test``/``y_test``.

    Returns:
        tuple: ``(X, y)`` where ``X`` is the image array cast to float32 and
        divided by 255, and ``y`` is the label array one-hot encoded into
        ``NUM_CLASSES`` columns.
    """
    path = _get_data_path() if data_path is None else data_path
    x_key, y_key = ('x_train', 'y_train') if is_train else ('x_test', 'y_test')

    # np.load on an .npz returns a handle that keeps the file open; the
    # context manager closes it once the arrays have been read into memory.
    with np.load(path) as data:
        X = data[x_key]
        y = data[y_key]

    # Normalize
    X = X.astype('float32')
    X = X / 255
    y = tf.keras.utils.to_categorical(y, NUM_CLASSES)

    return X, y

def _aggregate_models(models, num_examples=None):
    """Average per-layer weights across models (Federated Averaging).

    Args:
        models: Iterable of models, each an iterable of per-layer weight
            arrays.  Every model must have the same number of layers with
            the same shapes.
        num_examples: Optional sequence holding one non-negative number per
            model.  When given, each model contributes in proportion to its
            number (sample-weighted FedAvg); when omitted, all models count
            equally and the result is the plain mean.

    Returns:
        list: One averaged array per layer.  The inputs are not modified.

    Raises:
        ValueError: If ``models`` is empty, the models' layer shapes differ,
            or ``num_examples`` has the wrong length, contains a negative
            value, or does not sum to a positive value.
    """
    models = [list(model_weights) for model_weights in models]
    if not models:
        raise ValueError("_aggregate_models requires at least one model")

    if num_examples is None:
        coefficients = [1.0] * len(models)
    else:
        # Plain Python floats keep the dtype of the weight arrays unchanged.
        coefficients = [float(count) for count in num_examples]
        if len(coefficients) != len(models):
            raise ValueError(
                "num_examples must contain exactly one entry per model"
            )
        if any(count < 0 for count in coefficients):
            raise ValueError("num_examples must not contain negative values")

    total = sum(coefficients)
    if total <= 0:
        raise ValueError("num_examples must sum to a positive value")

    # Reject structurally different models instead of silently mis-averaging.
    expected_shapes = [np.shape(layer) for layer in models[0]]
    for position, model_weights in enumerate(models):
        if [np.shape(layer) for layer in model_weights] != expected_shapes:
            raise ValueError(
                f"model {position} does not match the layer shapes of model 0"
            )

    averaged = []
    for layers in zip(*models):
        weighted_sum = sum(
            np.asarray(layer) * coefficient
            for layer, coefficient in zip(layers, coefficients)
        )
        averaged.append(weighted_sum / total)

    return averaged

def _perform_fedavg(model_path, num_clients=None):
    """Load every client's model weights and return their average.

    Weights are read from ``/var/data/clients/<n>/<model_path>`` for
    ``n = 1 .. num_clients`` and combined with ``_aggregate_models``.

    Args:
        model_path: Path of the model file relative to each client's
            directory under ``/var/data/clients/<n>/``.
        num_clients: Number of client directories to read.  When ``None``,
            the module-level ``num_clients`` (set by the script entry point)
            is used, so existing single-argument calls keep working.

    Returns:
        The aggregated weights produced by ``_aggregate_models``.

    Raises:
        ValueError: If no client count is available or it is smaller than 1.
    """
    if num_clients is None:
        # The global only exists when the file is run as a script; look it up
        # defensively so importing this module does not end in a NameError.
        num_clients = globals().get('num_clients')
    if num_clients is None:
        raise ValueError(
            "num_clients is not configured: pass it explicitly or define "
            "the module-level 'num_clients'"
        )
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1")

    # The helper does not depend on the client, so create it once.
    helper = get_helper(HELPER_MODULE)

    # Collect models from clients
    client_models = [
        helper.load(f'/var/data/clients/{client_number}/{model_path}')
        for client_number in range(1, num_clients + 1)
    ]

    # Aggregate models using FedAvg
    return _aggregate_models(client_models)

def init_seed(out_path='seed.npz'):
    """Save the weights of a freshly compiled model as the seed model.

    Args:
        out_path: Destination passed to the helper's ``save``.
    """
    initial_weights = _compile_model().get_weights()
    get_helper(HELPER_MODULE).save(initial_weights, out_path)

def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1):
    """Train the model on the local training split and write the result.

    Args:
        in_model_path: Path the starting weights are loaded from.
        out_model_path: Path the weights (and their metadata) are written to.
        data_path: Optional dataset path forwarded to ``_load_data``.
        batch_size: Batch size passed to ``fit``.
        epochs: Number of epochs passed to ``fit``.
    """
    features, labels = _load_data(data_path)

    # Rebuild the architecture, then restore the incoming weights into it.
    network = _compile_model()
    io_helper = get_helper(HELPER_MODULE)
    network.set_weights(io_helper.load(in_model_path))

    network.fit(features, labels, batch_size=batch_size, epochs=epochs)

    # Persist the locally trained weights.
    io_helper.save(network.get_weights(), out_model_path)

    # Metadata needed for aggregation server side, stored as JSON.
    save_metadata(
        dict(num_examples=len(features), batch_size=batch_size, epochs=epochs),
        out_model_path,
    )

    # NOTE(review): the averaged weights below overwrite the locally trained
    # weights saved above at the same path -- confirm this client-side FedAvg
    # step is intended.
    io_helper.save(_perform_fedavg(out_model_path), out_model_path)

def validate(in_model_path, out_json_path, data_path=None):
    """Evaluate a model on the training and test splits and save the metrics.

    Args:
        in_model_path: Path the model weights are loaded from.
        out_json_path: Path handed to ``save_metrics`` for the report.
        data_path: Optional dataset path forwarded to ``_load_data``.
    """
    # Load data
    x_train, y_train = _load_data(data_path)
    x_test, y_test = _load_data(data_path, is_train=False)

    # Load model
    model = _compile_model()
    helper = get_helper(HELPER_MODULE)
    weights = helper.load(in_model_path)
    model.set_weights(weights)

    # Evaluate.  Each score is indexed as [loss, accuracy] below.  The extra
    # predict/argmax pass that used to follow was never read, so it is gone.
    model_score = model.evaluate(x_train, y_train)
    model_score_test = model.evaluate(x_test, y_test)

    # JSON schema
    report = {
        "training_loss": model_score[0],
        "training_accuracy": model_score[1],
        "test_loss": model_score_test[0],
        "test_accuracy": model_score_test[1],
    }

    # Save JSON
    save_metrics(report, out_json_path)

def infer(in_model_path, out_json_path, data_path=None):
    """Predict classes for the test split and write them to a JSON file.

    Args:
        in_model_path: Path the model weights are loaded from.
        out_json_path: File that receives ``{"predictions": [...]}``.
        data_path: Optional dataset path forwarded to ``_load_data``.
    """
    # The test split stands in for inference data; labels are not needed.
    samples, _ = _load_data(data_path, is_train=False)

    # Rebuild the architecture, then restore the given weights into it.
    network = _compile_model()
    io_helper = get_helper(HELPER_MODULE)
    network.set_weights(io_helper.load(in_model_path))

    # Index of the highest-scoring class for each sample.
    class_ids = np.argmax(network.predict(samples), axis=1)

    with open(out_json_path, "w") as fh:
        result = {'predictions': class_ids.tolist()}
        fh.write(json.dumps(result))

if __name__ == '__main__':
    # Module-level client count that _perform_fedavg reads when averaging.
    num_clients = 2  # Update with the actual number of clients

    # Command name -> callable exposed on the command line through Fire.
    commands = {
        'init_seed': init_seed,
        'train': train,
        'validate': validate,
        'infer': infer,
        '_get_data_path': _get_data_path,  # for testing
    }
    fire.Fire(commands)
