Skip to content
desperadoccy edited this page Jan 15, 2025 · 4 revisions

model参数介绍

model支持多样化加载,可以加载torch或者自己实现的模型。以如下方式配置:

"model": {
    "path": "torchvision.models.resnet.resnet18",
    "params": {
        "num_classes": 10
    }
},

其中,params会通过eval函数执行自定义,使用方式如:

"model": {
    "path": "model.HAR.cnn.CNN",
    "params": {
    "train_shape": "src_obj.train_ds.data.shape",
    "category": 12
    }
},

也可以通过自定义函数构建函数创建模型,如下:

"model": {
    "custom_create_fn": "lib.FedProto.FedProto.create_proto_model",
    "params": {
        "dataset": "mnist",
        "mode": "model_heter",
        "num_channels": 1,
        "num_classes": 10,
        "out_channels": 20
    }
},

模型加载也支持加载预训练模型,加载逻辑如下:

pretrained = config.get("pretrained", False)
pretrained_path = config.get("pretrained_path", "")
if pretrained or pretrained_path:
    if pretrained:
        # 尝试使用模型类的预训练参数(适用于 PyTorch 内置模型)
        if hasattr(model_class, 'from_pretrained'):
            # 对于支持 from_pretrained 的模型(如 HuggingFace 模型)
            model = model_class.from_pretrained(path, **params)
        else:
            # 对于 PyTorch 提供的模型
            model = model_class(pretrained=True, **params)

    if pretrained_path:
        state_dict = load_pretrained_weights(pretrained_path)
        model.load_state_dict(state_dict)

现有类介绍

CNN

CNN模型

CNN_pruning

CNN+剪枝模型

LeNet5

LeNet5模型

HAR相关

和HAR相关的一些模型 该文件夹来自仓库,作者已授权。

Clone this wiki locally