1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  package org.springframework.web.socket.server.standard;
18  
19  import java.util.ArrayList;
20  import java.util.HashMap;
21  import java.util.List;
22  import java.util.Map;
23  import javax.websocket.Decoder;
24  import javax.websocket.Encoder;
25  import javax.websocket.Endpoint;
26  import javax.websocket.Extension;
27  import javax.websocket.HandshakeResponse;
28  import javax.websocket.server.HandshakeRequest;
29  import javax.websocket.server.ServerEndpointConfig;
30  
31  import org.springframework.beans.factory.BeanFactory;
32  import org.springframework.beans.factory.BeanFactoryAware;
33  import org.springframework.util.Assert;
34  import org.springframework.web.socket.handler.BeanCreatingHandlerProvider;
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  public class ServerEndpointRegistration extends ServerEndpointConfig.Configurator
55  		implements ServerEndpointConfig, BeanFactoryAware {
56  
57  	private final String path;
58  
59  	private final BeanCreatingHandlerProvider<Endpoint> endpointProvider;
60  
61  	private final Endpoint endpoint;
62  
63      private List<Class<? extends Encoder>> encoders = new ArrayList<Class<? extends Encoder>>();
64  
65      private List<Class<? extends Decoder>> decoders = new ArrayList<Class<? extends Decoder>>();
66  
67  	private List<String> protocols = new ArrayList<String>();
68  
69  	private List<Extension> extensions = new ArrayList<Extension>();
70  
71  	private final Map<String, Object> userProperties = new HashMap<String, Object>();
72  
73  
74  	
75  
76  
77  
78  
79  
80  	public ServerEndpointRegistration(String path, Class<? extends Endpoint> endpointClass) {
81  		Assert.hasText(path, "path must not be empty");
82  		Assert.notNull(endpointClass, "endpointClass must not be null");
83  		this.path = path;
84  		this.endpointProvider = new BeanCreatingHandlerProvider<Endpoint>(endpointClass);
85  		this.endpoint = null;
86  	}
87  
88  	
89  
90  
91  
92  
93  
94  	public ServerEndpointRegistration(String path, Endpoint endpoint) {
95  		Assert.hasText(path, "path must not be empty");
96  		Assert.notNull(endpoint, "endpoint must not be null");
97  		this.path = path;
98  		this.endpointProvider = null;
99  		this.endpoint = endpoint;
100 	}
101 
102 
103 	@Override
104 	public String getPath() {
105 		return this.path;
106 	}
107 
108 	@Override
109 	public Class<? extends Endpoint> getEndpointClass() {
110 		return (this.endpoint != null) ?
111 				this.endpoint.getClass() : ((Class<? extends Endpoint>) this.endpointProvider.getHandlerType());
112 	}
113 
114 	public Endpoint getEndpoint() {
115 		return (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
116 	}
117 
118 	public void setSubprotocols(List<String> protocols) {
119 		this.protocols = protocols;
120 	}
121 
122 	@Override
123 	public List<String> getSubprotocols() {
124 		return this.protocols;
125 	}
126 
127 	public void setExtensions(List<Extension> extensions) {
128 		this.extensions = extensions;
129 	}
130 
131 	@Override
132 	public List<Extension> getExtensions() {
133 		return this.extensions;
134 	}
135 
136 	public void setUserProperties(Map<String, Object> userProperties) {
137 		this.userProperties.clear();
138 		this.userProperties.putAll(userProperties);
139 	}
140 
141 	@Override
142 	public Map<String, Object> getUserProperties() {
143 		return this.userProperties;
144 	}
145 
146 	public void setEncoders(List<Class<? extends Encoder>> encoders) {
147 		this.encoders = encoders;
148 	}
149 
150 	@Override
151 	public List<Class<? extends Encoder>> getEncoders() {
152 		return this.encoders;
153 	}
154 
155 	public void setDecoders(List<Class<? extends Decoder>> decoders) {
156 		this.decoders = decoders;
157 	}
158 
159 	@Override
160 	public List<Class<? extends Decoder>> getDecoders() {
161 		return this.decoders;
162 	}
163 
164 	@Override
165 	public Configurator getConfigurator() {
166 		return this;
167 	}
168 
169 	@Override
170 	public void setBeanFactory(BeanFactory beanFactory) {
171 		if (this.endpointProvider != null) {
172 			this.endpointProvider.setBeanFactory(beanFactory);
173 		}
174 	}
175 
176 
177 	
178 
179 	@SuppressWarnings("unchecked")
180 	@Override
181 	public final <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
182 		return (T) getEndpoint();
183 	}
184 
185 	@Override
186 	public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
187 		super.modifyHandshake(this, request, response);
188 	}
189 
190 	@Override
191 	public boolean checkOrigin(String originHeaderValue) {
192 		return super.checkOrigin(originHeaderValue);
193 	}
194 
195 	@Override
196 	public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
197 		return super.getNegotiatedSubprotocol(supported, requested);
198 	}
199 
200 	@Override
201 	public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
202 		return super.getNegotiatedExtensions(installed, requested);
203 	}
204 
205 
206 	@Override
207 	public String toString() {
208 		return "ServerEndpointRegistration for path '" + getPath() + "': " + getEndpointClass();
209 	}
210 }