Skip to content

Commit 912f7ef

Browse files
committed
deploy model example
1 parent a070479 commit 912f7ef

File tree

8 files changed

+151
-1
lines changed

8 files changed

+151
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
All notable changes to `slashml-python-client` aka `slashml` will be documented in this file.
33
This project adheres to [Semantic Versioning](https://semver.org/).
44

5+
## 0.1.4 - 2023-06-09
6+
7+
### Added
8+
- Model deployment
9+
510
## 0.1.3 - 2023-05-13
611

712
### Added

README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,50 @@ print(f"Summary = {response.summarization_data}")
6565
```
6666

6767

68+
#### Deploy your own Model
69+
<!-- write a code snippet in the minimum number of lines -->
70+
71+
```python
72+
from slashml import ModelDeployment
73+
import time
74+
75+
# you might have to install transformers and torch
76+
from transformers import pipeline
77+
78+
def train_model():
79+
# Bring in model from huggingface
80+
return pipeline('fill-mask', model='bert-base-uncased')
81+
82+
my_model = train_model()
83+
84+
# Replace `API_KEY` with your SlashML API token.
85+
API_KEY = "YOUR_API_KEY"
86+
87+
model = ModelDeployment(api_key=API_KEY)
88+
89+
# deploy model
90+
response = model.deploy(model_name='my_model_3', model=my_model)
91+
92+
# wait for it to be deployed
93+
time.sleep(2)
94+
status = model.status(model_version_id=response.id)
95+
96+
while status.status != 'READY':
97+
print(f'status: {status.status}')
98+
print('trying again in 5 seconds')
99+
time.sleep(5)
100+
status = model.status(model_version_id=response.id)
101+
102+
if status.status == 'FAILED':
103+
raise Exception('Model deployment failed')
104+
105+
# submit prediction
106+
input_text = 'Steve jobs is the [MASK] of Apple.'
107+
prediction = model.predict(model_version_id=response.id, model_input=input_text)
108+
print(prediction)
109+
```
110+
111+
68112
### View the list of service providers available
69113
```python
70114
from slashml import TextToSpeech
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from slashml import ModelDeployment
2+
import time
3+
4+
# you might have to install transfomers and torch
5+
from transformers import pipeline
6+
7+
def train_model():
8+
# Bring in model from huggingface
9+
return pipeline('fill-mask', model='bert-base-uncased')
10+
11+
my_model = train_model()
12+
13+
# Replace `API_KEY` with your SlasML API token.
14+
API_KEY = "YOUR_API_KEY"
15+
16+
model = ModelDeployment(api_key=API_KEY)
17+
18+
# deploy model
19+
response = model.deploy(model_name='my_model_3', model=my_model)
20+
21+
# wait for it to be deployed
22+
time.sleep(2)
23+
status = model.status(model_version_id=response.id)
24+
25+
while status.status != 'READY':
26+
print(f'status: {status.status}')
27+
print('trying again in 5 seconds')
28+
time.sleep(5)
29+
status = model.status(model_version_id=response.id)
30+
31+
if status.status == 'FAILED':
32+
raise Exception('Model deployment failed')
33+
34+
# submit prediction
35+
input_text = 'Steve jobs is the [MASK] of Apple.'
36+
prediction = model.predict(model_version_id=response.id, model_input=input_text)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
transformers==4.30.0
2+
torch==2.0.1

examples/text_to_speech_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
# Submit request
1616
job = model.execute(text=input_text, service_provider=service_provider)
1717

18-
print (f"\n\n\n You can access the audio file here: {job.audio_url}")
18+
print(f"\n\n\n You can access the audio file here: {job.audio_url}")

requires-install.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
requests==2.28.1
22
addict==2.4.0
3+
truss==0.4.8

slashml/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
from slashml.text_summarization import TextSummarization  # noqa: F401,E402
from slashml.speech_to_text import SpeechToText  # noqa: F401,E402
from slashml.text_to_speech import TextToSpeech  # noqa: F401,E402
from slashml.model_deployment import ModelDeployment  # noqa: F401,E402


# ModelDeployment is included so `from slashml import *` exposes the new
# class alongside the existing services.
__all__ = ["TextSummarization", "SpeechToText", "TextToSpeech", "ModelDeployment"]

slashml/model_deployment.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import json
2+
import requests
3+
import time
4+
import tarfile
5+
import truss
6+
from enum import Enum
7+
from .utils import generateURL, baseUrl, generateHeaders, formatResponse, getTaskStatus
8+
9+
10+
import os
11+
12+
13+
class ModelDeployment:
    """Client for deploying, monitoring, and querying custom ML models
    on the SlashML model-deployment service.

    Typical flow: ``deploy()`` a model, poll ``status()`` until it is
    READY, then call ``predict()`` with the model-version id.
    """

    _base_url = baseUrl("model-deployment", "v1")
    # Populated per-instance in __init__ with the caller's API key.
    _headers = None

    def __init__(self, api_key: str = None):
        """Store authentication headers built from *api_key*.

        Raises:
            Exception: if no API key is supplied.
        """
        # `is None` is the idiomatic null check (PEP 8), not `== None`.
        if api_key is None:
            raise Exception("API Key is required for model deployment")
        self._headers = generateHeaders(api_key)

    def create_tar_gz(self, *, folder_path, tar_gz_filename):
        """Package *folder_path* into the gzip-compressed tarball *tar_gz_filename*."""
        with tarfile.open(tar_gz_filename, "w:gz") as tar:
            tar.add(folder_path, arcname=os.path.basename(folder_path))

    def deploy(self, *, model_name: str, model):
        """Serialize *model* with truss and upload it for deployment.

        Args:
            model_name: display name registered with the service.
            model: the in-memory model object (e.g. a transformers
                pipeline) that truss knows how to package.

        Returns:
            The formatted API response; its ``id`` is the model-version id.
        """
        # truss writes a self-contained model directory we then archive.
        truss.create(model, 'my_model')
        self.create_tar_gz(folder_path='my_model', tar_gz_filename='my_model.tar.gz')

        url = generateURL(self._base_url, "models")
        payload = {
            "model_name": model_name,
        }

        # `with` guarantees the archive handle is closed (the original
        # leaked it); the leftover pdb.set_trace() has been removed — it
        # halted every deploy() call at an interactive debugger prompt.
        with open('my_model.tar.gz', "rb") as archive:
            files = [("model_file", ("my_model.tar.gz", archive, "application/octet-stream"))]
            response = requests.post(url, headers=self._headers, data=payload, files=files)

        return formatResponse(response)

    def status(self, *, model_version_id: str):
        """Check the deployment status of the given model version."""
        url = generateURL(self._base_url, "models", model_version_id, "status")
        response = requests.get(url, headers=self._headers)
        return formatResponse(response)

    def predict(self, model_version_id: str, model_input: str):
        """Submit *model_input* to the deployed model and return its prediction."""
        payload = json.dumps({
            "model_input": model_input
        })

        url = generateURL(self._base_url, "models", model_version_id, "predict")
        # Build a per-request header dict instead of permanently mutating
        # self._headers, which would leak Content-Type into later GETs.
        headers = {**self._headers, 'Content-Type': 'application/json'}
        response = requests.post(url, headers=headers, data=payload)
        return formatResponse(response)

0 commit comments

Comments
 (0)