-
Notifications
You must be signed in to change notification settings - Fork 21
model
desperadoccy edited this page Jan 15, 2025
·
4 revisions
model支持多样化加载,可以加载torch或者自己实现的模型。以如下方式配置:
"model": {
"path": "torchvision.models.resnet.resnet18",
"params": {
"num_classes": 10
}
},其中,params会通过eval函数执行自定义,使用方式如:
"model": {
"path": "model.HAR.cnn.CNN",
"params": {
"train_shape": "src_obj.train_ds.data.shape",
"category": 12
}
},也可以通过自定义函数构建函数创建模型,如下:
"model": {
"custom_create_fn": "lib.FedProto.FedProto.create_proto_model",
"params": {
"dataset": "mnist",
"mode": "model_heter",
"num_channels": 1,
"num_classes": 10,
"out_channels": 20
}
},模型加载也支持加载预训练模型,加载逻辑如下:
pretrained = config.get("pretrained", False)
pretrained_path = config.get("pretrained_path", "")
if pretrained or pretrained_path:
if pretrained:
# 尝试使用模型类的预训练参数(适用于 PyTorch 内置模型)
if hasattr(model_class, 'from_pretrained'):
# 对于支持 from_pretrained 的模型(如 HuggingFace 模型)
model = model_class.from_pretrained(path, **params)
else:
# 对于 PyTorch 提供的模型
model = model_class(pretrained=True, **params)
if pretrained_path:
state_dict = load_pretrained_weights(pretrained_path)
model.load_state_dict(state_dict)CNN模型
CNN+剪枝模型
LeNet5模型
和HAR相关的一些模型 该文件夹来自仓库,作者已授权。
Getting Started - 整体流程 - Module Guide - 现有算法 - Contact Us